Adding triangle count code
parent 8719ba83c8
commit 2093a17ff3
@@ -2,8 +2,6 @@ package org.apache.spark.graph
 import org.apache.spark._
-
-

 /**
  * The Analytics object contains a collection of basic graph analytics
  * algorithms that operate largely on the graph structure.
@@ -204,6 +202,65 @@ object Analytics extends Logging {
   } // end of connectedComponents

+  /**
+   * Compute the number of triangles passing through each vertex.
+   *
+   * @param graph
+   * @tparam VD
+   * @tparam ED
+   * @return
+   */
+  def triangleCount[VD: ClassManifest, ED: ClassManifest](rawGraph: Graph[VD, ED]):
+    Graph[Int, ED] = {
+    // Remove redundant edges
+    val graph = rawGraph.groupEdges( (a, b) => a ).cache
+
+    // Construct set representations of the neighborhoods
+    val nbrSets: VertexSetRDD[VertexSet] =
+      graph.collectNeighborIds(EdgeDirection.Both).mapValuesWithKeys { (vid, nbrs) =>
+        val set = new VertexSet //(math.ceil(nbrs.size / 0.7).toInt)
+        var i = 0
+        while (i < nbrs.size) {
+          // prevent self cycle
+          if (nbrs(i) != vid) set.add(nbrs(i))
+          i += 1
+        }
+        set
+      }
+
+    // join the sets with the graph
+    val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) {
+      (vid, _, optSet) => optSet.getOrElse(null)
+    }
+
+    // Edge function computes intersection of smaller vertex with larger vertex
+    def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Array[(Vid, Int)] = {
+      assert(et.srcAttr != null)
+      assert(et.dstAttr != null)
+      val (smallSet, largeSet) =
+        if (et.srcAttr.size < et.dstAttr.size) { (et.srcAttr, et.dstAttr) }
+        else { (et.dstAttr, et.srcAttr) }
+      val iter = smallSet.iterator()
+      var counter: Int = 0
+      while (iter.hasNext) {
+        val vid = iter.next
+        if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { counter += 1 }
+      }
+      Array((et.srcId, counter), (et.dstId, counter))
+    }
+
+    // compute the intersection along edges
+    val counters: VertexSetRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _)
+
+    // Merge counters with the graph and divide by two since each triangle is counted twice
+    graph.outerJoinVertices(counters) {
+      (vid, _, optCounter: Option[Int]) =>
+        val dblCount = optCounter.getOrElse(0)
+        // double count should be even (divisible by two)
+        assert((dblCount & 1) == 0)
+        dblCount / 2
+    }
+  } // end of triangleCount
+
   def main(args: Array[String]) = {
     val host = args(0)
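For each edge, the new edgeFunc intersects the smaller endpoint's neighbor set with the larger one; the common neighbors are exactly the triangles passing through that edge, and mapReduceTriplets sums these per-edge counts at each endpoint. A vertex meets each of its triangles through two incident edges, so the summed value is doubled, which is why the final join halves it. A minimal single-machine sketch of the same idea (plain Scala collections; the LocalTriangleCount name is illustrative and not part of this commit):

    // Illustrative local version of the neighbor-set-intersection triangle count.
    // Assumes an undirected simple graph given as (src, dst) pairs.
    object LocalTriangleCount {
      def triangleCounts(edges: Seq[(Long, Long)]): Map[Long, Int] = {
        // Drop self loops and duplicate edges, mirroring groupEdges above.
        val cleaned = edges.filter { case (u, v) => u != v }.distinct
        // Neighbor sets, mirroring collectNeighborIds(EdgeDirection.Both).
        val nbrs: Map[Long, Set[Long]] = cleaned
          .flatMap { case (u, v) => Seq(u -> v, v -> u) }
          .groupBy(_._1)
          .map { case (vid, vs) => vid -> vs.map(_._2).toSet }
        // Common neighbors of an edge's endpoints = triangles through that edge;
        // credit both endpoints, as edgeFunc does.
        val contributions = cleaned.flatMap { case (u, v) =>
          val common = (nbrs(u) intersect nbrs(v)) - u - v
          Seq(u -> common.size, v -> common.size)
        }
        // Each triangle at a vertex is seen via two incident edges: halve the sum.
        contributions.groupBy(_._1).map { case (vid, cs) => vid -> cs.map(_._2).sum / 2 }
      }

      def main(args: Array[String]): Unit = {
        // One triangle 0-1-2: every vertex reports exactly 1.
        println(triangleCounts(Seq(0L -> 1L, 1L -> 2L, 2L -> 0L)))
      }
    }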
@@ -277,7 +334,7 @@ object Analytics extends Logging {

         val sc = new SparkContext(host, "PageRank(" + fname + ")")

-        val graph = GraphLoader.textFile(sc, fname, a => 1.0F,
+        val graph = GraphLoader.edgeListFile(sc, fname,
           minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()

         val startTime = System.currentTimeMillis
@@ -329,7 +386,7 @@ object Analytics extends Logging {

         val sc = new SparkContext(host, "ConnectedComponents(" + fname + ")")
         //val graph = GraphLoader.textFile(sc, fname, a => 1.0F)
-        val graph = GraphLoader.textFile(sc, fname, a => 1.0F,
+        val graph = GraphLoader.edgeListFile(sc, fname,
           minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()
         val cc = Analytics.connectedComponents(graph)
         //val cc = if(isDynamic) Analytics.dynamicConnectedComponents(graph, numIter)
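Both the PageRank and ConnectedComponents drivers swap GraphLoader.textFile, which parsed a per-edge attribute (a => 1.0F), for GraphLoader.edgeListFile, which reads bare source/destination id pairs and takes only partitioning hints. A hedged sketch of the call shape used in this commit (the path and partition counts are placeholders):

    // Assumed usage, matching the signature that appears in the diff:
    val graph = GraphLoader.edgeListFile(sc, "path/to/edge-list.txt",
      minEdgePartitions = 4, minVertexPartitions = 4).cache()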
@@ -338,6 +395,31 @@ object Analytics extends Logging {
           sc.stop()
         }

+      case "triangles" => {
+        var numVPart = 4
+        var numEPart = 4
+
+        options.foreach {
+          case ("numEPart", v) => numEPart = v.toInt
+          case ("numVPart", v) => numVPart = v.toInt
+          case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
+        }
+
+        println("======================================")
+        println("| Triangle Count |")
+        println("--------------------------------------")
+        val sc = new SparkContext(host, "TriangleCount(" + fname + ")")
+        //val graph = GraphLoader.textFile(sc, fname, a => 1.0F)
+        val graph = GraphLoader.edgeListFileUndirected(sc, fname,
+          minEdgePartitions = numEPart, minVertexPartitions = numVPart).cache()
+        val triangles = Analytics.triangleCount(graph)
+        //val cc = if(isDynamic) Analytics.dynamicConnectedComponents(graph, numIter)
+        //         else Analytics.connectedComponents(graph, numIter)
+        println("Triangles: " + triangles.vertices.map { case (vid, data) => data.toLong }.reduce(_ + _) / 3)
+
+        sc.stop()
+      }

 //
 // case "shortestpath" => {
 //
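The driver totals the per-vertex counts and divides by three because each triangle has three corners and is counted once at each. For the two-triangle fixture used in the tests below (vertex 0 shared by both triangles), the per-vertex counts are 2, 1, 1, 1, 1; they sum to 6, and 6 / 3 = 2 triangles. A small REPL-style check of that arithmetic (fixture values transcribed from the tests):

    // Per-vertex counts for triangles {0, 1, 2} and {0, -1, -2}.
    val perVertex = Seq(0L -> 2, 1L -> 1, 2L -> 1, -1L -> 1, -2L -> 1)
    val total = perVertex.map(_._2.toLong).sum / 3  // 6 / 3 = 2
    assert(total == 2)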
@@ -132,7 +132,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
       val chain1 = (0 until 9).map(x => (x, x+1) )
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
-      val twoChains = Graph(rawEdges)
+      val twoChains = Graph(rawEdges, 1.0)
       val ccGraph = Analytics.connectedComponents(twoChains).cache()
       val vertices = ccGraph.vertices.collect
       for ( (id, cc) <- vertices ) {
@@ -153,7 +153,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
       val chain1 = (0 until 9).map(x => (x, x+1) )
       val chain2 = (10 until 20).map(x => (x, x+1) )
       val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
-      val twoChains = Graph(rawEdges).reverse
+      val twoChains = Graph(rawEdges, true).reverse
       val ccGraph = Analytics.connectedComponents(twoChains).cache()
       val vertices = ccGraph.vertices.collect
       for ( (id, cc) <- vertices ) {
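The extra constructor argument in these two tests (1.0 and true) appears to supply a default attribute when building a graph from raw edge tuples; presumably the bare Graph(rawEdges) overload no longer exists in this version. The assertions themselves are unchanged.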
@@ -167,8 +167,58 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
         else { assert(ccMap(id) === 10) }
       }
     }
-  } // end of chain connected components
+  } // end of reverse chain connected components
+
+  test("Count a single triangle") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
+      val graph = Graph(rawEdges, true).cache
+      val triangleCount = Analytics.triangleCount(graph)
+      val verts = triangleCount.vertices
+      verts.collect.foreach { case (vid, count) => assert(count === 1) }
+    }
+  }
+
+  test("Count two triangles") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val triangles = Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ) ++
+        Array( 0L -> -1L, -1L -> -2L, -2L -> 0L )
+      val rawEdges = sc.parallelize(triangles, 2)
+      val graph = Graph(rawEdges, true).cache
+      val triangleCount = Analytics.triangleCount(graph)
+      val verts = triangleCount.vertices
+      verts.collect.foreach { case (vid, count) =>
+        if (vid == 0) { assert(count === 2) }
+        else { assert(count === 1) }
+      }
+    }
+  }
+
+  test("Count two triangles with bi-directed edges") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val triangles =
+        Array( 0L -> 1L, 1L -> 2L, 2L -> 0L ) ++
+        Array( 0L -> -1L, -1L -> -2L, -2L -> 0L )
+      val revTriangles = triangles.map { case (a, b) => (b, a) }
+
+      val rawEdges = sc.parallelize(triangles ++ revTriangles, 2)
+      val graph = Graph(rawEdges, true).cache
+      val triangleCount = Analytics.triangleCount(graph)
+      val verts = triangleCount.vertices
+      verts.collect.foreach { case (vid, count) =>
+        if (vid == 0) { assert(count === 4) }
+        else { assert(count === 2) }
+      }
+    }
+  }
+
+  test("Count a single triangle with duplicate edges") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ) ++ Array( 0L->1L, 1L->2L, 2L->0L ), 2)
+      val graph = Graph(rawEdges, true).cache
+      val triangleCount = Analytics.triangleCount(graph)
+      val verts = triangleCount.vertices
+      verts.collect.foreach { case (vid, count) => assert(count === 1) }
+    }
+  }
+
 } // end of AnalyticsSuite
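The doubled expectations in the bi-directed test follow from the directed representation: groupEdges merges duplicates only within a direction, so each undirected link survives as two directed edges, edgeFunc fires once per directed edge, and every per-vertex count comes out twice as large; halving still applies, so vertex 0 gets 4 instead of 2 and the others get 2 instead of 1. The duplicate-edge test checks the complementary case: groupEdges collapses repeated edges in the same direction, so counts stay at 1.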