Refactor and add aggregator support
Refactored out the agg() and comp() methods from Pregel.run. Defined an implicit conversion to allow applications that don't use aggregators to avoid including a null argument for the result of the aggregator in the compute function.
This commit is contained in:
parent
c18fa3ebc6
commit
563c5e717c
|
@ -6,37 +6,62 @@ import spark.SparkContext._
|
|||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
object Pregel extends Logging {
|
||||
/**
|
||||
* Runs a Pregel job on the given vertices consisting of the
|
||||
* specified compute function.
|
||||
*
|
||||
* Before beginning the first superstep, the given messages are sent
|
||||
* to their destination vertices.
|
||||
*
|
||||
* During the job, the specified combiner functions are applied to
|
||||
* messages as they travel between vertices.
|
||||
*
|
||||
* The job halts and returns the resulting set of vertices when no
|
||||
* messages are being sent between vertices and all vertices have
|
||||
* voted to halt by setting their state to inactive.
|
||||
*/
|
||||
def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
|
||||
def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
|
||||
sc: SparkContext,
|
||||
verts: RDD[(String, V)],
|
||||
msgs: RDD[(String, M)],
|
||||
combiner: Combiner[M, C],
|
||||
numSplits: Int,
|
||||
superstep: Int = 0
|
||||
)(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
|
||||
msgs: RDD[(String, M)]
|
||||
)(
|
||||
combiner: Combiner[M, C] = new DefaultCombiner[M],
|
||||
aggregator: Aggregator[V, A] = new NullAggregator[V],
|
||||
superstep: Int = 0,
|
||||
numSplits: Int = sc.numCores
|
||||
)(
|
||||
compute: (V, Option[C], A, Int) => (V, Iterable[M])
|
||||
): RDD[V] = {
|
||||
|
||||
logInfo("Starting superstep "+superstep+".")
|
||||
val startTime = System.currentTimeMillis
|
||||
|
||||
// Bring together vertices and messages
|
||||
val aggregated = agg(verts, aggregator)
|
||||
val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
|
||||
val grouped = verts.groupWith(combinedMsgs)
|
||||
val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
|
||||
|
||||
// Run compute on each vertex
|
||||
val timeTaken = System.currentTimeMillis - startTime
|
||||
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
|
||||
|
||||
// Check stopping condition and iterate
|
||||
val noActivity = numMsgs == 0 && numActiveVerts == 0
|
||||
if (noActivity) {
|
||||
processed.map { case (id, (vert, msgs)) => vert }
|
||||
} else {
|
||||
val newVerts = processed.mapValues { case (vert, msgs) => vert }
|
||||
val newMsgs = processed.flatMap {
|
||||
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
|
||||
}
|
||||
run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Aggregates the given vertices using the given aggregator, or does
|
||||
* nothing if it is a NullAggregator.
|
||||
*/
|
||||
def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
|
||||
case _: NullAggregator[_] =>
|
||||
None
|
||||
case _ =>
|
||||
verts.map {
|
||||
case (id, vert) => aggregator.createAggregator(vert)
|
||||
}.reduce(aggregator.mergeAggregators(_, _))
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes the given vertex-message RDD using the compute
|
||||
* function. Returns the processed RDD, the number of messages
|
||||
* created, and the number of active vertices.
|
||||
*/
|
||||
def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
|
||||
var numMsgs = sc.accumulator(0)
|
||||
var numActiveVerts = sc.accumulator(0)
|
||||
val processed = grouped.flatMapValues {
|
||||
|
@ -46,7 +71,7 @@ object Pregel extends Logging {
|
|||
compute(v, c match {
|
||||
case Seq(comb) => Some(comb)
|
||||
case Seq() => None
|
||||
}, superstep)
|
||||
})
|
||||
|
||||
numMsgs += newMsgs.size
|
||||
if (newVert.active)
|
||||
|
@ -58,30 +83,36 @@ object Pregel extends Logging {
|
|||
// Force evaluation of processed RDD for accurate performance measurements
|
||||
processed.foreach(x => {})
|
||||
|
||||
val timeTaken = System.currentTimeMillis - startTime
|
||||
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
|
||||
(processed, numMsgs.value, numActiveVerts.value)
|
||||
}
|
||||
|
||||
// Check stopping condition and iterate
|
||||
val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
|
||||
if (noActivity) {
|
||||
processed.map { case (id, (vert, msgs)) => vert }
|
||||
} else {
|
||||
val newVerts = processed.mapValues { case (vert, msgs) => vert }
|
||||
val newMsgs = processed.flatMap {
|
||||
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
|
||||
}
|
||||
run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
|
||||
}
|
||||
/**
|
||||
* Converts a compute function that doesn't take an aggregator to
|
||||
* one that does, so it can be passed to Pregel.run.
|
||||
*/
|
||||
implicit def addAggregatorArg[
|
||||
V <: Vertex : Manifest, M <: Message : Manifest, C
|
||||
](
|
||||
compute: (V, Option[C], Int) => (V, Iterable[M])
|
||||
): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
|
||||
(vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Simplify Combiner interface and make it more OO.
|
||||
trait Combiner[M, C] {
|
||||
def createCombiner(msg: M): C
|
||||
def mergeMsg(combiner: C, msg: M): C
|
||||
def mergeCombiners(a: C, b: C): C
|
||||
}
|
||||
|
||||
@serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
|
||||
trait Aggregator[V, A] {
|
||||
def createAggregator(vert: V): A
|
||||
def mergeAggregators(a: A, b: A): A
|
||||
}
|
||||
|
||||
@serializable
|
||||
class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
|
||||
def createCombiner(msg: M): ArrayBuffer[M] =
|
||||
ArrayBuffer(msg)
|
||||
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
|
||||
|
@ -90,6 +121,12 @@ trait Combiner[M, C] {
|
|||
a ++= b
|
||||
}
|
||||
|
||||
@serializable
|
||||
class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
|
||||
def createAggregator(vert: V): Option[Nothing] = None
|
||||
def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a Pregel vertex.
|
||||
*
|
||||
|
|
|
@ -5,6 +5,8 @@ import spark.SparkContext._
|
|||
|
||||
import scala.math.min
|
||||
|
||||
import bagel.Pregel._
|
||||
|
||||
object ShortestPath {
|
||||
def main(args: Array[String]) {
|
||||
if (args.length < 4) {
|
||||
|
@ -49,7 +51,7 @@ object ShortestPath {
|
|||
messages.count()+" messages.")
|
||||
|
||||
// Do the computation
|
||||
val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
|
||||
val compute = addAggregatorArg {
|
||||
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
|
||||
val newValue = messageMinValue match {
|
||||
case Some(minVal) => min(self.value, minVal)
|
||||
|
@ -65,6 +67,7 @@ object ShortestPath {
|
|||
|
||||
(new SPVertex(self.id, newValue, self.outEdges, false), outbox)
|
||||
}
|
||||
val result = Pregel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)
|
||||
|
||||
// Print the result
|
||||
System.err.println("Shortest path from "+startVertex+" to all vertices:")
|
||||
|
|
|
@ -3,6 +3,8 @@ package bagel
|
|||
import spark._
|
||||
import spark.SparkContext._
|
||||
|
||||
import bagel.Pregel._
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.xml.{XML,NodeSeq}
|
||||
|
||||
|
@ -60,9 +62,9 @@ object WikipediaPageRank {
|
|||
val messages = sc.parallelize(List[(String, PRMessage)]())
|
||||
val result =
|
||||
if (noCombiner) {
|
||||
Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
|
||||
Pregel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
|
||||
} else {
|
||||
Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
|
||||
Pregel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
|
||||
}
|
||||
|
||||
// Print the result
|
||||
|
|
|
@ -10,6 +10,8 @@ import scala.collection.mutable.ArrayBuffer
|
|||
|
||||
import spark._
|
||||
|
||||
import bagel.Pregel._
|
||||
|
||||
@serializable class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex
|
||||
@serializable class TestMessage(val targetId: String) extends Message
|
||||
|
||||
|
@ -20,10 +22,10 @@ class BagelSuite extends FunSuite with Assertions {
|
|||
val msgs = sc.parallelize(Array[(String, TestMessage)]())
|
||||
val numSupersteps = 5
|
||||
val result =
|
||||
Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
|
||||
Pregel.run(sc, verts, msgs)()(addAggregatorArg {
|
||||
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
|
||||
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
|
||||
}
|
||||
})
|
||||
for (vert <- result.collect)
|
||||
assert(vert.age === numSupersteps)
|
||||
}
|
||||
|
@ -34,7 +36,7 @@ class BagelSuite extends FunSuite with Assertions {
|
|||
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
|
||||
val numSupersteps = 5
|
||||
val result =
|
||||
Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
|
||||
Pregel.run(sc, verts, msgs)()(addAggregatorArg {
|
||||
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
|
||||
val msgsOut =
|
||||
msgs match {
|
||||
|
@ -44,7 +46,7 @@ class BagelSuite extends FunSuite with Assertions {
|
|||
new ArrayBuffer[TestMessage]()
|
||||
}
|
||||
(new TestVertex(self.id, self.active, self.age + 1), msgsOut)
|
||||
}
|
||||
})
|
||||
for (vert <- result.collect)
|
||||
assert(vert.age === numSupersteps)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue