Add Bagel, an implementation of Pregel on Spark
This commit is contained in:
parent
94ba95bcb2
commit
c0736f6f68
103
bagel/src/main/scala/bagel/Pregel.scala
Normal file
103
bagel/src/main/scala/bagel/Pregel.scala
Normal file
|
@ -0,0 +1,103 @@
|
|||
package bagel
|
||||
|
||||
import spark._
|
||||
import spark.SparkContext._
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
object Pregel extends Logging {
|
||||
/**
|
||||
* Runs a Pregel job on the given vertices, running the specified
|
||||
* compute function on each vertex in every superstep. Before
|
||||
* beginning the first superstep, sends the given messages to their
|
||||
* destination vertices. In the join stage, launches splits
|
||||
* separate tasks (where splits is manually specified to work
|
||||
* around a bug in Spark).
|
||||
*
|
||||
* Halts when no more messages are being sent between vertices, and
|
||||
* all vertices have voted to halt by setting their state to
|
||||
* Inactive.
|
||||
*/
|
||||
def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
|
||||
println("Starting superstep "+superstep+".")
|
||||
val startTime = System.currentTimeMillis
|
||||
|
||||
// Bring together vertices and messages
|
||||
println("Joining vertices and messages...")
|
||||
val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
|
||||
println("verts.splits.size = " + verts.splits.size)
|
||||
println("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
|
||||
println("verts.partitioner = " + verts.partitioner)
|
||||
println("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
|
||||
val joined = verts.groupWith(combinedMsgs)
|
||||
println("joined.splits.size = " + joined.splits.size)
|
||||
println("joined.partitioner = " + joined.partitioner)
|
||||
//val joined = graph.groupByKeyAsymmetrical(messageCombiner, defaultCombined, mergeCombined, splits)
|
||||
println("Done joining vertices and messages.")
|
||||
|
||||
// Run compute on each vertex
|
||||
println("Running compute on each vertex...")
|
||||
var messageCount = sc.accumulator(0)
|
||||
var activeVertexCount = sc.accumulator(0)
|
||||
val processed = joined.flatMapValues {
|
||||
case (Seq(), _) => None
|
||||
case (Seq(v), Seq(comb)) =>
|
||||
val (newVertex, newMessages) = compute(v, comb, superstep)
|
||||
messageCount += newMessages.size
|
||||
if (newVertex.active)
|
||||
activeVertexCount += 1
|
||||
Some((newVertex, newMessages))
|
||||
//val result = ArrayBuffer[(String, Either[V, M])]((newVertex.id, Left(newVertex)))
|
||||
//result ++= newMessages.map(m => (m.targetId, Right(m)))
|
||||
case (Seq(v), Seq()) =>
|
||||
val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
|
||||
messageCount += newMessages.size
|
||||
if (newVertex.active)
|
||||
activeVertexCount += 1
|
||||
Some((newVertex, newMessages))
|
||||
}.cache
|
||||
//MATEI: Added this
|
||||
processed.foreach(x => {})
|
||||
println("Done running compute on each vertex.")
|
||||
|
||||
println("Checking stopping condition...")
|
||||
val stop = messageCount.value == 0 && activeVertexCount.value == 0
|
||||
|
||||
val timeTaken = System.currentTimeMillis - startTime
|
||||
println("Superstep %d took %d s".format(superstep, timeTaken / 1000))
|
||||
|
||||
val newVerts = processed.mapValues(_._1)
|
||||
val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
|
||||
|
||||
if (superstep >= 10)
|
||||
processed.map { _._2._1 }
|
||||
else
|
||||
run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, superstep + 1)(compute)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a Pregel vertex. Must be subclassed to store state
|
||||
* along with each vertex. Must be annotated with @serializable.
|
||||
*/
|
||||
trait Vertex {
|
||||
def id: String
|
||||
def active: Boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a Pregel message to a target vertex. Must be
|
||||
* subclassed to contain a payload. Must be annotated with @serializable.
|
||||
*/
|
||||
trait Message {
|
||||
def targetId: String
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a directed edge between two vertices. Owned by the
|
||||
* source vertex, and contains the ID of the target vertex. Must
|
||||
* be subclassed to store state along with each edge. Must be annotated with @serializable.
|
||||
*/
|
||||
trait Edge {
|
||||
def targetId: String
|
||||
}
|
86
bagel/src/main/scala/bagel/ShortestPath.scala
Normal file
86
bagel/src/main/scala/bagel/ShortestPath.scala
Normal file
|
@ -0,0 +1,86 @@
|
|||
package bagel
|
||||
|
||||
import spark._
|
||||
import spark.SparkContext._
|
||||
|
||||
import scala.math.min
|
||||
|
||||
/*
|
||||
object ShortestPath {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length < 4) {
|
||||
System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
|
||||
"<numSplits> <host>")
|
||||
System.exit(-1)
|
||||
}
|
||||
|
||||
val graphFile = args(0)
|
||||
val startVertex = args(1)
|
||||
val numSplits = args(2).toInt
|
||||
val host = args(3)
|
||||
val sc = new SparkContext(host, "ShortestPath")
|
||||
|
||||
// Parse the graph data from a file into two RDDs, vertices and messages
|
||||
val lines =
|
||||
(sc.textFile(graphFile)
|
||||
.filter(!_.matches("^\\s*#.*"))
|
||||
.map(line => line.split("\t")))
|
||||
|
||||
val vertices: RDD[(String, Either[SPVertex, SPMessage])] =
|
||||
(lines.groupBy(line => line(0))
|
||||
.map {
|
||||
case (vertexId, lines) => {
|
||||
val outEdges = lines.collect {
|
||||
case Array(_, targetId, edgeValue) =>
|
||||
new SPEdge(targetId, edgeValue.toInt)
|
||||
}
|
||||
|
||||
(vertexId, Left[SPVertex, SPMessage](new SPVertex(vertexId, Int.MaxValue, outEdges, true)))
|
||||
}
|
||||
})
|
||||
|
||||
val messages: RDD[(String, Either[SPVertex, SPMessage])] =
|
||||
(lines.filter(_.length == 2)
|
||||
.map {
|
||||
case Array(vertexId, messageValue) =>
|
||||
(vertexId, Right[SPVertex, SPMessage](new SPMessage(vertexId, messageValue.toInt)))
|
||||
})
|
||||
|
||||
val graph: RDD[(String, Either[SPVertex, SPMessage])] = vertices ++ messages
|
||||
|
||||
System.err.println("Read "+vertices.count()+" vertices and "+
|
||||
messages.count()+" messages.")
|
||||
|
||||
// Do the computation
|
||||
def messageCombiner(minSoFar: Int, message: SPMessage): Int =
|
||||
min(minSoFar, message.value)
|
||||
|
||||
val result = Pregel.run(sc, graph, numSplits, messageCombiner, () => Int.MaxValue, min _) {
|
||||
(self: SPVertex, messageMinValue: Int, superstep: Int) =>
|
||||
val newValue = min(self.value, messageMinValue)
|
||||
|
||||
val outbox =
|
||||
if (newValue != self.value)
|
||||
self.outEdges.map(edge =>
|
||||
new SPMessage(edge.targetId, newValue + edge.value))
|
||||
else
|
||||
List()
|
||||
|
||||
(new SPVertex(self.id, newValue, self.outEdges, false), outbox)
|
||||
}
|
||||
|
||||
// Print the result
|
||||
System.err.println("Shortest path from "+startVertex+" to all vertices:")
|
||||
val shortest = result.map(vertex =>
|
||||
"%s\t%s\n".format(vertex.id, vertex.value match {
|
||||
case x if x == Int.MaxValue => "inf"
|
||||
case x => x
|
||||
})).collect.mkString
|
||||
println(shortest)
|
||||
}
|
||||
}
|
||||
|
||||
@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
|
||||
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
|
||||
@serializable class SPMessage(val targetId: String, val value: Int) extends Message
|
||||
*/
|
201
bagel/src/main/scala/bagel/WikipediaPageRank.scala
Normal file
201
bagel/src/main/scala/bagel/WikipediaPageRank.scala
Normal file
|
@ -0,0 +1,201 @@
|
|||
package bagel
|
||||
|
||||
import spark._
|
||||
import spark.SparkContext._
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import scala.xml.{XML,NodeSeq}
|
||||
|
||||
import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
|
||||
|
||||
import com.esotericsoftware.kryo._
|
||||
|
||||
object WikipediaPageRank {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length < 4) {
|
||||
System.err.println("Usage: PageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
|
||||
System.exit(-1)
|
||||
}
|
||||
|
||||
System.setProperty("spark.serialization", "spark.KryoSerialization")
|
||||
System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
|
||||
|
||||
val inputFile = args(0)
|
||||
val threshold = args(1).toDouble
|
||||
val numSplits = args(2).toInt
|
||||
val host = args(3)
|
||||
val noCombiner = args.length > 4 && args(4).nonEmpty
|
||||
val sc = new SparkContext(host, "WikipediaPageRank")
|
||||
|
||||
// Parse the Wikipedia page data into a graph
|
||||
val input = sc.textFile(inputFile)
|
||||
|
||||
println("Counting vertices...")
|
||||
val numVertices = input.count()
|
||||
println("Done counting vertices.")
|
||||
|
||||
println("Parsing input file...")
|
||||
val vertices: RDD[(String, PRVertex)] = input.map(line => {
|
||||
val fields = line.split("\t")
|
||||
val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
|
||||
val links =
|
||||
if (body == "\\N")
|
||||
NodeSeq.Empty
|
||||
else
|
||||
try {
|
||||
XML.loadString(body) \\ "link" \ "target"
|
||||
} catch {
|
||||
case e: org.xml.sax.SAXParseException =>
|
||||
System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
|
||||
NodeSeq.Empty
|
||||
}
|
||||
val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
|
||||
val id = new String(title)
|
||||
(id, (new PRVertex(id, 1.0 / numVertices, outEdges, true)))
|
||||
})
|
||||
val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
|
||||
|
||||
println("Done parsing input file.")
|
||||
println("Input file had "+graph.count+" vertices.")
|
||||
|
||||
// Do the computation
|
||||
val epsilon = 0.01 / numVertices
|
||||
val result =
|
||||
if (noCombiner) {
|
||||
val messages = sc.parallelize(List[(String, PRMessage)]())
|
||||
Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon))
|
||||
} else {
|
||||
val messages = sc.parallelize(List[(String, PRMessage)]())
|
||||
Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon))
|
||||
}
|
||||
|
||||
// Print the result
|
||||
System.err.println("Articles with PageRank >= "+threshold+":")
|
||||
val top = result.filter(_.value >= threshold).map(vertex =>
|
||||
"%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
|
||||
println(top)
|
||||
}
|
||||
|
||||
object Combiner {
|
||||
def messageCombiner(minSoFar: Double, message: PRMessage): Double =
|
||||
minSoFar + message.value
|
||||
|
||||
def mergeCombined(a: Double, b: Double) = a + b
|
||||
|
||||
def defaultCombined(): Double = 0.0
|
||||
|
||||
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
|
||||
val newValue =
|
||||
if (messageSum != 0)
|
||||
0.15 / numVertices + 0.85 * messageSum
|
||||
else
|
||||
self.value
|
||||
|
||||
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
|
||||
|
||||
val outbox =
|
||||
if (!terminate)
|
||||
self.outEdges.map(edge =>
|
||||
new PRMessage(edge.targetId, newValue / self.outEdges.size))
|
||||
else
|
||||
ArrayBuffer[PRMessage]()
|
||||
|
||||
(new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
|
||||
}
|
||||
}
|
||||
|
||||
object NoCombiner {
|
||||
def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
|
||||
messagesSoFar += message
|
||||
|
||||
def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
|
||||
a ++= b
|
||||
|
||||
def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()
|
||||
|
||||
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
|
||||
Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
|
||||
}
|
||||
}
|
||||
|
||||
@serializable class PRVertex() extends Vertex with Externalizable {
|
||||
var id: String = _
|
||||
var value: Double = _
|
||||
var outEdges: ArrayBuffer[PREdge] = _
|
||||
var active: Boolean = true
|
||||
|
||||
def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
|
||||
this()
|
||||
this.id = id
|
||||
this.value = value
|
||||
this.outEdges = outEdges
|
||||
this.active = active
|
||||
}
|
||||
|
||||
def writeExternal(out: ObjectOutput) {
|
||||
out.writeUTF(id)
|
||||
out.writeDouble(value)
|
||||
out.writeInt(outEdges.length)
|
||||
for (e <- outEdges)
|
||||
out.writeUTF(e.targetId)
|
||||
out.writeBoolean(active)
|
||||
}
|
||||
|
||||
def readExternal(in: ObjectInput) {
|
||||
id = in.readUTF()
|
||||
value = in.readDouble()
|
||||
val numEdges = in.readInt()
|
||||
outEdges = new ArrayBuffer[PREdge](numEdges)
|
||||
for (i <- 0 until numEdges) {
|
||||
outEdges += new PREdge(in.readUTF())
|
||||
}
|
||||
active = in.readBoolean()
|
||||
}
|
||||
}
|
||||
|
||||
@serializable class PRMessage() extends Message with Externalizable {
|
||||
var targetId: String = _
|
||||
var value: Double = _
|
||||
|
||||
def this(targetId: String, value: Double) {
|
||||
this()
|
||||
this.targetId = targetId
|
||||
this.value = value
|
||||
}
|
||||
|
||||
def writeExternal(out: ObjectOutput) {
|
||||
out.writeUTF(targetId)
|
||||
out.writeDouble(value)
|
||||
}
|
||||
|
||||
def readExternal(in: ObjectInput) {
|
||||
targetId = in.readUTF()
|
||||
value = in.readDouble()
|
||||
}
|
||||
}
|
||||
|
||||
@serializable class PREdge() extends Edge with Externalizable {
|
||||
var targetId: String = _
|
||||
|
||||
def this(targetId: String) {
|
||||
this()
|
||||
this.targetId = targetId
|
||||
}
|
||||
|
||||
def writeExternal(out: ObjectOutput) {
|
||||
out.writeUTF(targetId)
|
||||
}
|
||||
|
||||
def readExternal(in: ObjectInput) {
|
||||
targetId = in.readUTF()
|
||||
}
|
||||
}
|
||||
|
||||
class PRKryoRegistrator extends KryoRegistrator {
|
||||
def registerClasses(kryo: Kryo) {
|
||||
kryo.register(classOf[PRVertex])
|
||||
kryo.register(classOf[PRMessage])
|
||||
kryo.register(classOf[PREdge])
|
||||
}
|
||||
}
|
|
@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
|
|||
lazy val examples =
|
||||
project("examples", "Spark Examples", new ExamplesProject(_), core)
|
||||
|
||||
lazy val bagel = project("bagel", "Bagel", core)
|
||||
|
||||
class CoreProject(info: ProjectInfo)
|
||||
extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
|
||||
{}
|
||||
|
|
Loading…
Reference in a new issue