diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala
new file mode 100644
index 0000000000..08ff1d8a01
--- /dev/null
+++ b/bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -0,0 +1,159 @@
+package spark.bagel
+
+import spark._
+import spark.SparkContext._
+
+import scala.collection.mutable.ArrayBuffer
+
+object Bagel extends Logging {
+  def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
+    sc: SparkContext,
+    verts: RDD[(String, V)],
+    msgs: RDD[(String, M)]
+  )(
+    combiner: Combiner[M, C] = new DefaultCombiner[M],
+    aggregator: Aggregator[V, A] = new NullAggregator[V],
+    superstep: Int = 0,
+    numSplits: Int = sc.numCores
+  )(
+    compute: (V, Option[C], A, Int) => (V, Iterable[M])
+  ): RDD[V] = {
+
+    logInfo("Starting superstep "+superstep+".")
+    val startTime = System.currentTimeMillis
+
+    val aggregated = agg(verts, aggregator)
+    val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
+    val grouped = verts.groupWith(combinedMsgs)
+    val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+    val timeTaken = System.currentTimeMillis - startTime
+    logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+    // Check stopping condition and iterate
+    val noActivity = numMsgs == 0 && numActiveVerts == 0
+    if (noActivity) {
+      processed.map { case (id, (vert, msgs)) => vert }
+    } else {
+      val newVerts = processed.mapValues { case (vert, msgs) => vert }
+      val newMsgs = processed.flatMap {
+        case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+      }
+      run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
+    }
+  }
+
+  /**
+   * Aggregates the given vertices using the given aggregator, or does
+   * nothing if it is a NullAggregator.
+   */
+  def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
+    case _: NullAggregator[_] =>
+      None
+    case _ =>
+      verts.map {
+        case (id, vert) => aggregator.createAggregator(vert)
+      }.reduce(aggregator.mergeAggregators(_, _))
+  }
+
+  /**
+   * Processes the given vertex-message RDD using the compute
+   * function. Returns the processed RDD, the number of messages
+   * created, and the number of active vertices.
+   */
+  def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
+    var numMsgs = sc.accumulator(0)
+    var numActiveVerts = sc.accumulator(0)
+    val processed = grouped.flatMapValues {
+      case (Seq(), _) => None
+      case (Seq(v), c) =>
+        val (newVert, newMsgs) =
+          compute(v, c match {
+            case Seq(comb) => Some(comb)
+            case Seq() => None
+          })
+
+        numMsgs += newMsgs.size
+        if (newVert.active)
+          numActiveVerts += 1
+
+        Some((newVert, newMsgs))
+    }.cache
+
+    // Force evaluation of processed RDD for accurate performance measurements
+    processed.foreach(x => {})
+
+    (processed, numMsgs.value, numActiveVerts.value)
+  }
+
+  /**
+   * Converts a compute function that doesn't take an aggregator to
+   * one that does, so it can be passed to Bagel.run.
+   */
+  implicit def addAggregatorArg[
+    V <: Vertex : Manifest, M <: Message : Manifest, C
+  ](
+    compute: (V, Option[C], Int) => (V, Iterable[M])
+  ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
+    (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
+  }
+}
+
+// TODO: Simplify Combiner interface and make it more OO.
+trait Combiner[M, C] {
+  def createCombiner(msg: M): C
+  def mergeMsg(combiner: C, msg: M): C
+  def mergeCombiners(a: C, b: C): C
+}
+
+trait Aggregator[V, A] {
+  def createAggregator(vert: V): A
+  def mergeAggregators(a: A, b: A): A
+}
+
+@serializable
+class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+  def createCombiner(msg: M): ArrayBuffer[M] =
+    ArrayBuffer(msg)
+  def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
+    combiner += msg
+  def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
+    a ++= b
+}
+
+@serializable
+class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
+  def createAggregator(vert: V): Option[Nothing] = None
+  def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+}
+
+/**
+ * Represents a Bagel vertex.
+ *
+ * Subclasses may store state along with each vertex and must be
+ * annotated with @serializable.
+ */
+trait Vertex {
+  def id: String
+  def active: Boolean
+}
+
+/**
+ * Represents a Bagel message to a target vertex.
+ *
+ * Subclasses may contain a payload to deliver to the target vertex
+ * and must be annotated with @serializable.
+ */
+trait Message {
+  def targetId: String
+}
+
+/**
+ * Represents a directed edge between two vertices.
+ *
+ * Subclasses may store state along each edge and must be annotated
+ * with @serializable.
+ */
+trait Edge {
+  def targetId: String
+}
diff --git a/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala b/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala
new file mode 100644
index 0000000000..a7fd386310
--- /dev/null
+++ b/bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala
@@ -0,0 +1,96 @@
+package spark.bagel.examples
+
+import spark._
+import spark.SparkContext._
+
+import scala.math.min
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+object ShortestPath {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
+                         "<numSplits> <host>")
+      System.exit(-1)
+    }
+
+    val graphFile = args(0)
+    val startVertex = args(1)
+    val numSplits = args(2).toInt
+    val host = args(3)
+    val sc = new SparkContext(host, "ShortestPath")
+
+    // Parse the graph data from a file into two RDDs, vertices and messages
+    val lines =
+      (sc.textFile(graphFile)
+       .filter(!_.matches("^\\s*#.*"))
+       .map(line => line.split("\t")))
+
+    val vertices: RDD[(String, SPVertex)] =
+      (lines.groupBy(line => line(0))
+       .map {
+         case (vertexId, lines) => {
+           val outEdges = lines.collect {
+             case Array(_, targetId, edgeValue) =>
+               new SPEdge(targetId, edgeValue.toInt)
+           }
+
+           (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true))
+         }
+       })
+
+    val messages: RDD[(String, SPMessage)] =
+      (lines.filter(_.length == 2)
+       .map {
+         case Array(vertexId, messageValue) =>
+           (vertexId, new SPMessage(vertexId, messageValue.toInt))
+       })
+
+    System.err.println("Read "+vertices.count()+" vertices and "+
+                       messages.count()+" messages.")
+
+    // Do the computation
+    val compute = addAggregatorArg {
+      (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
+        val newValue = messageMinValue match {
+          case Some(minVal) =>
+            min(self.value, minVal)
+          case None => self.value
+        }
+
+        val outbox =
+          if (newValue != self.value)
+            self.outEdges.map(edge =>
+              new SPMessage(edge.targetId, newValue + edge.value))
+          else
+            List()
+
+        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
+    }
+    val result = Bagel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)
+
+    // Print the result
+    System.err.println("Shortest path from "+startVertex+" to all vertices:")
+    val shortest = result.map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value match {
+        case x if x == Int.MaxValue => "inf"
+        case x => x
+      })).collect.mkString
+    println(shortest)
+  }
+}
+
+@serializable
+object MinCombiner extends Combiner[SPMessage, Int] {
+  def createCombiner(msg: SPMessage): Int =
+    msg.value
+  def mergeMsg(combiner: Int, msg: SPMessage): Int =
+    min(combiner, msg.value)
+  def mergeCombiners(a: Int, b: Int): Int =
+    min(a, b)
+}
+
+@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
+@serializable class SPEdge(val targetId: String, val value: Int) extends Edge
+@serializable class SPMessage(val targetId: String, val value: Int) extends Message
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
new file mode 100644
index 0000000000..1bce5bebad
--- /dev/null
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
@@ -0,0 +1,158 @@
+package spark.bagel.examples
+
+import spark._
+import spark.SparkContext._
+
+import spark.bagel._
+import spark.bagel.Bagel._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.xml.{XML,NodeSeq}
+
+import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
+
+import com.esotericsoftware.kryo._
+
+object WikipediaPageRank {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+      System.exit(-1)
+    }
+
+    System.setProperty("spark.serialization", "spark.KryoSerialization")
+    System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+
+    val inputFile = args(0)
+    val threshold = args(1).toDouble
+    val numSplits = args(2).toInt
+    val host = args(3)
+    val noCombiner = args.length > 4 && args(4).nonEmpty
+    val sc = new SparkContext(host, "WikipediaPageRank")
+
+    // Parse the Wikipedia page data into a graph
+    val input = sc.textFile(inputFile)
+
+    println("Counting vertices...")
+    val numVertices = input.count()
+    println("Done counting vertices.")
+
+    println("Parsing input file...")
+    val vertices: RDD[(String, PRVertex)] = input.map(line => {
+      val fields = line.split("\t")
+      val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+      val links =
+        if (body == "\\N")
+          NodeSeq.Empty
+        else
+          try {
+            XML.loadString(body) \\ "link" \ "target"
+          } catch {
+            case e: org.xml.sax.SAXParseException =>
+              System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+              NodeSeq.Empty
+          }
+      val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
+      val id = new String(title)
+      (id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
+    }).cache
+    println("Done parsing input file.")
+
+    // Do the computation
+    val epsilon = 0.01 / numVertices
+    val messages = sc.parallelize(List[(String, PRMessage)]())
+    val result =
+      if (noCombiner) {
+        Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
+      } else {
+        Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
+      }
+
+    // Print the result
+    System.err.println("Articles with PageRank >= "+threshold+":")
+    val top = result.filter(_.value >= threshold).map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
+    println(top)
+  }
+}
+
+@serializable
+object PRCombiner extends Combiner[PRMessage, Double] {
+  def createCombiner(msg: PRMessage): Double =
+    msg.value
+  def mergeMsg(combiner: Double, msg: PRMessage): Double =
+    combiner + msg.value
+  def mergeCombiners(a: Double, b: Double): Double =
+    a + b
+
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+    val newValue = messageSum match {
+      case Some(msgSum) if msgSum != 0 =>
+        0.15 / numVertices + 0.85 * msgSum
+      case _ => self.value
+    }
+
+    val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
+
+    val outbox =
+      if (!terminate)
+        self.outEdges.map(edge =>
+          new PRMessage(edge.targetId, newValue / self.outEdges.size))
+      else
+        ArrayBuffer[PRMessage]()
+
+    (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+  }
+}
+
+@serializable
+object PRNoCombiner extends DefaultCombiner[PRMessage] {
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+    PRCombiner.compute(numVertices, epsilon)(self, messages match {
+      case Some(msgs) => Some(msgs.map(_.value).sum)
+      case None => None
+    }, superstep)
+}
+
+@serializable class PRVertex() extends Vertex {
+  var id: String = _
+  var value: Double = _
+  var outEdges: ArrayBuffer[PREdge] = _
+  var active: Boolean = true
+
+  def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
+    this()
+    this.id = id
+    this.value = value
+    this.outEdges = outEdges
+    this.active = active
+  }
+}
+
+@serializable class PRMessage() extends Message {
+  var targetId: String = _
+  var value: Double = _
+
+  def this(targetId: String, value: Double) {
+    this()
+    this.targetId = targetId
+    this.value = value
+  }
+}
+
+@serializable class PREdge() extends Edge {
+  var targetId: String = _
+
+  def this(targetId: String) {
+    this()
+    this.targetId = targetId
+  }
+}
+
+class PRKryoRegistrator extends KryoRegistrator {
+  def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[PRVertex])
+    kryo.register(classOf[PRMessage])
+    kryo.register(classOf[PREdge])
+  }
+}
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
new file mode 100644
index 0000000000..1b47fc9ed5
--- /dev/null
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -0,0 +1,53 @@
+package spark.bagel
+
+import org.scalatest.{FunSuite, Assertions}
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+
+import spark.bagel.Bagel._
+
+@serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex
+@serializable class TestMessage(val targetId: String) extends Message
+
+class BagelSuite extends FunSuite with Assertions {
+  test("halting by voting") {
+    val sc = new SparkContext("local", "test")
+    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, true, 0))))
+    val msgs =
+      sc.parallelize(Array[(String, TestMessage)]())
+    val numSupersteps = 5
+    val result =
+      Bagel.run(sc, verts, msgs)()(addAggregatorArg {
+        (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
+          (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
+      })
+    for (vert <- result.collect)
+      assert(vert.age === numSupersteps)
+  }
+
+  test("halting by message silence") {
+    val sc = new SparkContext("local", "test")
+    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, false, 0))))
+    val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
+    val numSupersteps = 5
+    val result =
+      Bagel.run(sc, verts, msgs)()(addAggregatorArg {
+        (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
+          val msgsOut =
+            msgs match {
+              case Some(ms) if (superstep < numSupersteps - 1) =>
+                ms
+              case _ =>
+                new ArrayBuffer[TestMessage]()
+            }
+          (new TestVertex(self.id, self.active, self.age + 1), msgsOut)
+      })
+    for (vert <- result.collect)
+      assert(vert.age === numSupersteps)
+  }
+}
diff --git a/project/build/SparkProject.scala b/project/build/SparkProject.scala
index 484daf5c50..a6520d1f03 100644
--- a/project/build/SparkProject.scala
+++ b/project/build/SparkProject.scala
@@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
   lazy val examples = project("examples", "Spark Examples", new ExamplesProject(_), core)
 
+  lazy val bagel = project("bagel", "Bagel", new BagelProject(_), core)
+
   class CoreProject(info: ProjectInfo)
   extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
   {}
@@ -21,6 +23,10 @@ extends ParentProject(info) with IdeaProject
   class ExamplesProject(info: ProjectInfo)
   extends DefaultProject(info) with Eclipsify with IdeaProject
   {}
+
+  class BagelProject(info: ProjectInfo)
+  extends DefaultProject(info) with DepJar with XmlTestReport
+  {}
 }
diff --git a/run b/run
index dd656f66cb..d3346f53e7 100755
--- a/run
+++ b/run
@@ -35,6 +35,7 @@ export JAVA_OPTS
 CORE_DIR=$FWDIR/core
 EXAMPLES_DIR=$FWDIR/examples
+BAGEL_DIR=$FWDIR/bagel
 
 # Build up classpath
 CLASSPATH="$SPARK_CLASSPATH:$CORE_DIR/target/scala_2.8.1/classes:$MESOS_CLASSPATH"
@@ -60,6 +61,7 @@ CLASSPATH+=:$EXAMPLES_DIR/target/scala_2.8.1/classes
 for jar in $CORE_DIR/lib/hadoop-0.20.2/lib/*.jar; do
   CLASSPATH+=:$jar
 done
+CLASSPATH+=:$BAGEL_DIR/target/scala_2.8.1/classes
 export CLASSPATH # Needed for spark-shell
 if [ -n "$SCALA_HOME" ]; then
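
Usage note: below is a minimal sketch of how a driver program might call the new Bagel API with the default combiner and no aggregator, modeled on the "halting by voting" test in BagelSuite. The CounterVertex, CounterMessage, and CounterExample names and the "local" master URL are illustrative assumptions, not code introduced by this patch.

package spark.bagel.examples

import scala.collection.mutable.ArrayBuffer

import spark._
import spark.bagel._
import spark.bagel.Bagel._

// Illustrative vertex and message types; any @serializable Vertex/Message subclasses work.
@serializable class CounterVertex(val id: String, val active: Boolean, val count: Int) extends Vertex
@serializable class CounterMessage(val targetId: String) extends Message

object CounterExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "CounterExample")
    val verts = sc.parallelize(List("a", "b", "c").map(id => (id, new CounterVertex(id, true, 0))))
    val msgs = sc.parallelize(List[(String, CounterMessage)]())

    // Empty second argument list picks the defaults (DefaultCombiner, NullAggregator,
    // superstep 0, numSplits = sc.numCores). Each vertex increments its counter once
    // per superstep, sends no messages, and votes to halt at superstep 2, so the run
    // terminates when no vertices are active and no messages are in flight.
    val result = Bagel.run(sc, verts, msgs)()(addAggregatorArg {
      (self: CounterVertex, msgs: Option[ArrayBuffer[CounterMessage]], superstep: Int) =>
        (new CounterVertex(self.id, superstep < 2, self.count + 1), Array[CounterMessage]())
    })

    result.collect.foreach(v => println(v.id + "\t" + v.count))
  }
}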