Merge pull request #91 from amplab/standalone-pagerank

Standalone PageRank
This commit is contained in:
Reynold Xin 2013-12-14 12:52:18 -08:00
commit 9bf192b01c
3 changed files with 114 additions and 64 deletions

View file

@ -54,8 +54,6 @@ object Analytics extends Logging {
taskType match { taskType match {
case "pagerank" => { case "pagerank" => {
var numIter = Int.MaxValue
var isDynamic = false
var tol:Float = 0.001F var tol:Float = 0.001F
var outFname = "" var outFname = ""
var numVPart = 4 var numVPart = 4
@ -63,8 +61,6 @@ object Analytics extends Logging {
var partitionStrategy: PartitionStrategy = RandomVertexCut var partitionStrategy: PartitionStrategy = RandomVertexCut
options.foreach{ options.foreach{
case ("numIter", v) => numIter = v.toInt
case ("dynamic", v) => isDynamic = v.toBoolean
case ("tol", v) => tol = v.toFloat case ("tol", v) => tol = v.toFloat
case ("output", v) => outFname = v case ("output", v) => outFname = v
case ("numVPart", v) => numVPart = v.toInt case ("numVPart", v) => numVPart = v.toInt
@ -73,17 +69,8 @@ object Analytics extends Logging {
case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
} }
if(!isDynamic && numIter == Int.MaxValue) {
println("Set number of iterations!")
sys.exit(1)
}
println("======================================") println("======================================")
println("| PageRank |") println("| PageRank |")
println("--------------------------------------")
println(" Using parameters:")
println(" \tDynamic: " + isDynamic)
if(isDynamic) println(" \t |-> Tolerance: " + tol)
println(" \tNumIter: " + numIter)
println("======================================") println("======================================")
val sc = new SparkContext(host, "PageRank(" + fname + ")") val sc = new SparkContext(host, "PageRank(" + fname + ")")
@ -91,22 +78,18 @@ object Analytics extends Logging {
val graph = GraphLoader.edgeListFile(sc, fname, val graph = GraphLoader.edgeListFile(sc, fname,
minEdgePartitions = numEPart, partitionStrategy = partitionStrategy).cache() minEdgePartitions = numEPart, partitionStrategy = partitionStrategy).cache()
val startTime = System.currentTimeMillis
println("GRAPHX: starting tasks")
println("GRAPHX: Number of vertices " + graph.vertices.count) println("GRAPHX: Number of vertices " + graph.vertices.count)
println("GRAPHX: Number of edges " + graph.edges.count) println("GRAPHX: Number of edges " + graph.edges.count)
//val pr = Analytics.pagerank(graph, numIter) //val pr = Analytics.pagerank(graph, numIter)
val pr = if(isDynamic) PageRank.runUntillConvergence(graph, tol, numIter) val pr = PageRank.runStandalone(graph, tol)
else PageRank.run(graph, numIter)
println("GRAPHX: Total rank: " + pr.vertices.map{ case (id,r) => r }.reduce(_+_) )
if (!outFname.isEmpty) {
println("Saving pageranks of pages to " + outFname)
pr.vertices.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
}
println("GRAPHX: Runtime: " + ((System.currentTimeMillis - startTime)/1000.0) + " seconds")
Thread.sleep(100000) println("GRAPHX: Total rank: " + pr.map(_._2).reduce(_+_))
if (!outFname.isEmpty) {
logWarning("Saving pageranks of pages to " + outFname)
pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
}
sc.stop() sc.stop()
} }

View file

@ -1,9 +1,10 @@
package org.apache.spark.graph.algorithms package org.apache.spark.graph.algorithms
import org.apache.spark.Logging
import org.apache.spark.graph._ import org.apache.spark.graph._
object PageRank { object PageRank extends Logging {
/** /**
* Run PageRank for a fixed number of iterations returning a graph * Run PageRank for a fixed number of iterations returning a graph
@ -60,7 +61,7 @@ object PageRank {
.mapVertices( (id, attr) => 1.0 ) .mapVertices( (id, attr) => 1.0 )
// Display statistics about pagerank // Display statistics about pagerank
println(pagerankGraph.statistics) logInfo(pagerankGraph.statistics.toString)
// Define the three functions needed to implement PageRank in the GraphX // Define the three functions needed to implement PageRank in the GraphX
// version of Pregel // version of Pregel
@ -124,7 +125,7 @@ object PageRank {
.mapVertices( (id, attr) => (0.0, 0.0) ) .mapVertices( (id, attr) => (0.0, 0.0) )
// Display statistics about pagerank // Display statistics about pagerank
println(pagerankGraph.statistics) logInfo(pagerankGraph.statistics.toString)
// Define the three functions needed to implement PageRank in the GraphX // Define the three functions needed to implement PageRank in the GraphX
// version of Pregel // version of Pregel
@ -151,4 +152,49 @@ object PageRank {
Pregel(pagerankGraph, initialMessage)(vertexProgram, sendMessage, messageCombiner) Pregel(pagerankGraph, initialMessage)(vertexProgram, sendMessage, messageCombiner)
.mapVertices((vid, attr) => attr._1) .mapVertices((vid, attr) => attr._1)
} // end of deltaPageRank } // end of deltaPageRank
/**
 * Run a delta-propagating PageRank built directly on `mapReduceTriplets` and
 * `deltaJoinVertices`, returning only the per-vertex ranks.
 *
 * Every vertex starts with rank `resetProb` and an outstanding delta of `resetProb`.
 * In each round, every vertex whose mask is still set sends
 * `delta * (1 / outDegree) * (1 - resetProb)` along each of its out-edges; the
 * contributions arriving at a vertex are summed, and sums that are not strictly
 * greater than `tol` are dropped. The surviving deltas are added to the ranks and
 * become the outstanding deltas of the next round. The loop stops as soon as a
 * round produces no surviving delta.
 *
 * @tparam VD the original vertex attribute type (ignored by the computation)
 * @tparam ED the original edge attribute type (ignored by the computation)
 * @param graph the graph on which to compute PageRank
 * @param tol a summed delta is applied and propagated only if it is strictly greater
 *            than this value
 * @param resetProb the initial rank and initial delta of every vertex; messages are
 *                  scaled by `1.0 - resetProb`
 * @return the accumulated rank of every vertex of `graph`
 */
def runStandalone[VD: Manifest, ED: Manifest](
    graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): VertexRDD[Double] = {

  // Initialize the ranks
  var ranks: VertexRDD[Double] = graph.vertices.mapValues((vid, attr) => resetProb).cache()

  // Initialize the delta graph where each vertex stores its delta and each edge knows its weight.
  // The edge weight is the reciprocal of the source vertex's out-degree; a vertex missing from
  // `outDegrees` gets degree 0, but such a vertex has no out-edges whose weight could be read.
  var deltaGraph: Graph[Double, Double] =
    graph.outerJoinVertices(graph.outDegrees)((vid, vdata, deg) => deg.getOrElse(0))
      .mapTriplets(e => 1.0 / e.srcAttr)
      .mapVertices((vid, degree) => resetProb).cache()
  var numDeltas: Long = ranks.count()

  var i = 0
  val weight = (1.0 - resetProb)
  while (numDeltas > 0) {
    // Compute new deltas. Only source vertices whose mask is set contribute a message.
    val deltas = deltaGraph
      .mapReduceTriplets[Double](
        et => {
          if (et.srcMask) Iterator((et.dstId, et.srcAttr * et.attr * weight))
          else Iterator.empty
        },
        _ + _)
      .filter { case (vid, delta) => delta > tol }
      .cache()
    numDeltas = deltas.count()
    logInfo("Standalone PageRank: iter %d has %d deltas".format(i, numDeltas))

    // Apply deltas. Sets the mask for each vertex to false if it does not appear in deltas.
    deltaGraph = deltaGraph.deltaJoinVertices(deltas).cache()

    // Update ranks. The result is cached so that the forcing action below, the next
    // iteration's join and the caller all reuse it instead of re-evaluating every previous
    // join back to the initial (cached) ranks.
    ranks = ranks.leftZipJoin(deltas) { (vid, oldRank, deltaOpt) =>
      oldRank + deltaOpt.getOrElse(0.0)
    }.cache()
    ranks.foreach(x => {}) // force the iteration for ease of debugging

    i += 1
  }

  ranks
}
} }

View file

@ -51,35 +51,38 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator") System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
/**
 * Sum of squared differences between two rank assignments.
 *
 * The comparison is driven by the vertices of `a`: a vertex of `a` that has no
 * entry in `b` is compared against a rank of 0.0.
 */
def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = {
  val squaredErrors = a.leftJoin(b) { (vid, rankA, rankBOpt) =>
    val diff = rankA - rankBOpt.getOrElse(0.0)
    diff * diff
  }
  squaredErrors.map(_._2).sum
}
test("Star PageRank") { test("Star PageRank") {
withSpark(new SparkContext("local", "test")) { sc => withSpark(new SparkContext("local", "test")) { sc =>
val nVertices = 100 val nVertices = 100
val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() val starGraph = GraphGenerators.starGraph(sc, nVertices).cache()
val resetProb = 0.15 val resetProb = 0.15
val prGraph1 = PageRank.run(starGraph, 1, resetProb) val errorTol = 1.0e-5
val prGraph2 = PageRank.run(starGraph, 2, resetProb)
val notMatching = prGraph1.vertices.zipJoin(prGraph2.vertices) { (vid, pr1, pr2) => val staticRanks1 = PageRank.run(starGraph, numIter = 1, resetProb).vertices.cache()
if (pr1 != pr2) { 1 } else { 0 } val staticRanks2 = PageRank.run(starGraph, numIter = 2, resetProb).vertices.cache()
// Static PageRank should only take 2 iterations to converge
val notMatching = staticRanks1.zipJoin(staticRanks2) { (vid, pr1, pr2) =>
if (pr1 != pr2) 1 else 0
}.map { case (vid, test) => test }.sum }.map { case (vid, test) => test }.sum
assert(notMatching === 0) assert(notMatching === 0)
//prGraph2.vertices.foreach(println(_))
val errors = prGraph2.vertices.map { case (vid, pr) => val staticErrors = staticRanks2.map { case (vid, pr) =>
val correct = (vid > 0 && pr == resetProb) || val correct = (vid > 0 && pr == resetProb) ||
(vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5) (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
if ( !correct ) { 1 } else { 0 } if (!correct) 1 else 0
} }
assert(errors.sum === 0) assert(staticErrors.sum === 0)
val prGraph3 = PageRank.runUntillConvergence(starGraph, 0, resetProb) val dynamicRanks = PageRank.runUntillConvergence(starGraph, 0, resetProb).vertices.cache()
val errors2 = prGraph2.vertices.leftJoin(prGraph3.vertices){ (vid, pr1, pr2Opt) => val standaloneRanks = PageRank.runStandalone(starGraph, 0, resetProb).cache()
pr2Opt match { assert(compareRanks(staticRanks2, dynamicRanks) < errorTol)
case Some(pr2) if(pr1 == pr2) => 0 assert(compareRanks(staticRanks2, standaloneRanks) < errorTol)
case _ => 1
}
}.map { case (vid, test) => test }.sum
assert(errors2 === 0)
} }
} // end of test Star PageRank } // end of test Star PageRank
@ -87,27 +90,46 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
test("Grid PageRank") { test("Grid PageRank") {
withSpark(new SparkContext("local", "test")) { sc => withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache() val rows = 10
val cols = 10
val resetProb = 0.15 val resetProb = 0.15
val prGraph1 = PageRank.run(gridGraph, 50, resetProb).cache() val tol = 0.0001
val prGraph2 = PageRank.runUntillConvergence(gridGraph, 0.0001, resetProb).cache() val numIter = 50
val error = prGraph1.vertices.zipJoin(prGraph2.vertices) { case (id, a, b) => (a - b) * (a - b) } val errorTol = 1.0e-5
.map { case (id, error) => error }.sum val gridGraph = GraphGenerators.gridGraph(sc, rows, cols).cache()
//prGraph1.vertices.zipJoin(prGraph2.vertices) { (id, a, b) => (a, b, a-b) }.foreach(println(_))
println(error) val staticRanks = PageRank.run(gridGraph, numIter, resetProb).vertices.cache()
assert(error < 1.0e-5) val dynamicRanks = PageRank.runUntillConvergence(gridGraph, tol, resetProb).vertices.cache()
val pr3: RDD[(Vid, Double)] = sc.parallelize(GridPageRank(10,10, 50, resetProb)) val standaloneRanks = PageRank.runStandalone(gridGraph, tol, resetProb).cache()
val error2 = prGraph1.vertices.leftJoin(pr3) { (id, a, bOpt) => val referenceRanks = VertexRDD(sc.parallelize(GridPageRank(rows, cols, numIter, resetProb)))
val b: Double = bOpt.get
(a - b) * (a - b) assert(compareRanks(staticRanks, referenceRanks) < errorTol)
}.map { case (id, error) => error }.sum assert(compareRanks(dynamicRanks, referenceRanks) < errorTol)
//prGraph1.vertices.leftJoin(pr3) { (id, a, b) => (a, b) }.foreach( println(_) ) assert(compareRanks(standaloneRanks, referenceRanks) < errorTol)
println(error2)
assert(error2 < 1.0e-5)
} }
} // end of Grid PageRank } // end of Grid PageRank
test("Chain PageRank") {
  withSpark(new SparkContext("local", "test")) { sc =>
    // Parameters shared by the three PageRank variants under test.
    val resetProb = 0.15
    val tol = 0.0001
    val numIter = 10
    val errorTol = 1.0e-5

    // Directed chain 0 -> 1 -> ... -> 9, held in a single partition.
    val chainPairs = (0 until 9).map(src => (src, src + 1))
    val edgeTuples = sc.parallelize(chainPairs, 1).map { case (src, dst) =>
      (src.toLong, dst.toLong)
    }
    val chainGraph = Graph.fromEdgeTuples(edgeTuples, 1.0).cache()

    val staticRanks = PageRank.run(chainGraph, numIter, resetProb).vertices.cache()
    val dynamicRanks =
      PageRank.runUntillConvergence(chainGraph, tol, resetProb).vertices.cache()
    val standaloneRanks = PageRank.runStandalone(chainGraph, tol, resetProb).cache()

    // All three implementations must agree up to the squared-error tolerance.
    assert(compareRanks(staticRanks, dynamicRanks) < errorTol)
    assert(compareRanks(dynamicRanks, standaloneRanks) < errorTol)
  }
}
test("Grid Connected Components") { test("Grid Connected Components") {
withSpark(new SparkContext("local", "test")) { sc => withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache() val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).cache()
@ -167,7 +189,6 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} }
} }
val ccMap = vertices.toMap val ccMap = vertices.toMap
println(ccMap)
for ( id <- 0 until 20 ) { for ( id <- 0 until 20 ) {
if (id < 10) { if (id < 10) {
assert(ccMap(id) === 0) assert(ccMap(id) === 0)