commit
1a06f707e3
|
@ -157,6 +157,16 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
|
||||||
/** Return the value at the specified position. */
|
/** Return the value at the specified position. */
|
||||||
def getValue(pos: Int): T = _data(pos)
|
def getValue(pos: Int): T = _data(pos)
|
||||||
|
|
||||||
|
/**
 * Return an iterator over the values held in this set.
 *
 * The iterator starts at `nextPos(0)` and advances with `nextPos(pos + 1)`; it is
 * exhausted once `nextPos` yields `INVALID_POS`. `next()` does not check `hasNext`,
 * so callers must not call it on an exhausted iterator.
 *
 * The result type is declared explicitly so that the anonymous class (and its
 * mutable `pos` cursor) does not leak into the public signature.
 */
def iterator(): Iterator[T] = new Iterator[T] {
  // Position of the element that the next call to `next()` will return.
  var pos = nextPos(0)

  override def hasNext: Boolean = pos != INVALID_POS

  override def next(): T = {
    val current = getValue(pos)
    // Move the cursor to the following position reported by nextPos.
    pos = nextPos(pos + 1)
    current
  }
}
|
||||||
|
|
||||||
/** Return the value at the specified position. */
|
/** Return the value at the specified position. */
|
||||||
def getValueSafe(pos: Int): T = {
|
def getValueSafe(pos: Int): T = {
|
||||||
assert(_bitset.get(pos))
|
assert(_bitset.get(pos))
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package org.apache.spark.graph
|
package org.apache.spark.graph
|
||||||
|
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The Graph abstractly represents a graph with arbitrary objects
|
* The Graph abstractly represents a graph with arbitrary objects
|
||||||
|
@ -70,6 +70,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
|
||||||
*/
|
*/
|
||||||
val triplets: RDD[EdgeTriplet[VD, ED]]
|
val triplets: RDD[EdgeTriplet[VD, ED]]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def persist(newLevel: StorageLevel): Graph[VD, ED]
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a graph that is cached when first created. This is used to
|
* Return a graph that is cached when first created. This is used to
|
||||||
* pin a graph in memory enabling multiple queries to reuse the same
|
* pin a graph in memory enabling multiple queries to reuse the same
|
||||||
|
|
|
@ -2,7 +2,7 @@ package org.apache.spark.graph
|
||||||
|
|
||||||
import com.esotericsoftware.kryo.Kryo
|
import com.esotericsoftware.kryo.Kryo
|
||||||
|
|
||||||
import org.apache.spark.graph.impl.{EdgePartition, MessageToPartition}
|
import org.apache.spark.graph.impl._
|
||||||
import org.apache.spark.serializer.KryoRegistrator
|
import org.apache.spark.serializer.KryoRegistrator
|
||||||
import org.apache.spark.util.collection.BitSet
|
import org.apache.spark.util.collection.BitSet
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
|
||||||
kryo.register(classOf[Edge[Object]])
|
kryo.register(classOf[Edge[Object]])
|
||||||
kryo.register(classOf[MutableTuple2[Object, Object]])
|
kryo.register(classOf[MutableTuple2[Object, Object]])
|
||||||
kryo.register(classOf[MessageToPartition[Object]])
|
kryo.register(classOf[MessageToPartition[Object]])
|
||||||
|
kryo.register(classOf[VertexBroadcastMsg[Object]])
|
||||||
|
kryo.register(classOf[AggregationMsg[Object]])
|
||||||
kryo.register(classOf[(Vid, Object)])
|
kryo.register(classOf[(Vid, Object)])
|
||||||
kryo.register(classOf[EdgePartition[Object]])
|
kryo.register(classOf[EdgePartition[Object]])
|
||||||
kryo.register(classOf[BitSet])
|
kryo.register(classOf[BitSet])
|
||||||
|
|
|
@ -98,14 +98,14 @@ object Pregel {
|
||||||
: Graph[VD, ED] = {
|
: Graph[VD, ED] = {
|
||||||
|
|
||||||
// Receive the first set of messages
|
// Receive the first set of messages
|
||||||
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
|
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg)).cache
|
||||||
|
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < numIter) {
|
while (i < numIter) {
|
||||||
// compute the messages
|
// compute the messages
|
||||||
val messages = g.mapReduceTriplets(sendMsg, mergeMsg)
|
val messages = g.mapReduceTriplets(sendMsg, mergeMsg)
|
||||||
// receive the messages
|
// receive the messages
|
||||||
g = g.joinVertices(messages)(vprog)
|
g = g.joinVertices(messages)(vprog).cache
|
||||||
// count the iteration
|
// count the iteration
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,8 @@ import org.apache.spark.SparkContext._
|
||||||
import org.apache.spark.rdd._
|
import org.apache.spark.rdd._
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
|
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
|
||||||
|
import org.apache.spark.graph.impl.AggregationMsg
|
||||||
|
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `VertexSetIndex` maintains the per-partition mapping from
|
* The `VertexSetIndex` maintains the per-partition mapping from
|
||||||
|
@ -659,6 +660,43 @@ object VertexSetRDD {
|
||||||
apply(rdd,index, (v:V) => v, reduceFunc, reduceFunc)
|
apply(rdd,index, (v:V) => v, reduceFunc, reduceFunc)
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Build a VertexSetRDD over an existing index by folding aggregation messages into it.
 *
 * `rdd` must have the same partitioner as `index.rdd` (asserted), and every partition of
 * `index.rdd` must hold exactly one index structure (asserted). For each message the slot
 * of `msg.vid` is looked up in that structure; the first value seen for a slot is stored
 * as is, later values are merged into it with `reduceFunc`.
 *
 * @param rdd        the messages to aggregate
 * @param index      the index whose slots receive the aggregated values
 * @param reduceFunc merges two values that target the same vertex id
 * @return a VertexSetRDD built from `index` and the per-partition (values, bitset) tables
 * @throws SparkException if a message carries a vertex id that is not in the index
 */
def aggregate[V: ClassManifest](
    rdd: RDD[AggregationMsg[V]],
    index: VertexSetIndex,
    reduceFunc: (V, V) => V): VertexSetRDD[V] = {

  val cleanedReduce = index.rdd.context.clean(reduceFunc)
  assert(rdd.partitioner == index.rdd.partitioner)

  // One (values, occupancy bitset) pair is produced per partition of the index.
  val valueTables: RDD[(Array[V], BitSet)] =
    index.rdd.zipPartitions(rdd) { (indexIter, msgIter) =>
      // Each index partition carries a single lookup structure.
      val vidIndex = indexIter.next()
      assert(!indexIter.hasNext)

      val slots = new Array[V](vidIndex.capacity)
      val occupied = new BitSet(vidIndex.capacity)

      msgIter.foreach { msg =>
        val pos = vidIndex.getPos(msg.vid)
        // A set NONEXISTENCE_MASK bit means the key is missing from the index.
        if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
          throw new SparkException("Error: Trying to bind an external index " +
            "to an RDD which contains keys that are not in the index.")
        }
        val slot = pos & OpenHashSet.POSITION_MASK
        if (occupied.get(slot)) {
          // Slot already holds a value: merge the new one into it.
          slots(slot) = cleanedReduce(slots(slot), msg.data)
        } else {
          // First value for this slot: mark it and store the value.
          occupied.set(slot)
          slots(slot) = msg.data
        }
      }
      Iterator((slots, occupied))
    }

  new VertexSetRDD(index, valueTables)
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct a vertex set from an RDD using an existing index and a
|
* Construct a vertex set from an RDD using an existing index and a
|
||||||
* user defined `combiner` to merge duplicate vertices.
|
* user defined `combiner` to merge duplicate vertices.
|
||||||
|
@ -675,11 +713,11 @@ object VertexSetRDD {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
def apply[V: ClassManifest, C: ClassManifest](
|
def apply[V: ClassManifest, C: ClassManifest](
|
||||||
rdd: RDD[(Vid,V)],
|
rdd: RDD[(Vid,V)],
|
||||||
index: VertexSetIndex,
|
index: VertexSetIndex,
|
||||||
createCombiner: V => C,
|
createCombiner: V => C,
|
||||||
mergeValue: (C, V) => C,
|
mergeValue: (C, V) => C,
|
||||||
mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
|
mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
|
||||||
val cCreateCombiner = index.rdd.context.clean(createCombiner)
|
val cCreateCombiner = index.rdd.context.clean(createCombiner)
|
||||||
val cMergeValue = index.rdd.context.clean(mergeValue)
|
val cMergeValue = index.rdd.context.clean(mergeValue)
|
||||||
val cMergeCombiners = index.rdd.context.clean(mergeCombiners)
|
val cMergeCombiners = index.rdd.context.clean(mergeCombiners)
|
||||||
|
|
|
@ -5,15 +5,15 @@ import scala.collection.JavaConversions._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
|
|
||||||
import org.apache.spark.SparkContext._
|
import org.apache.spark.SparkContext._
|
||||||
import org.apache.spark.HashPartitioner
|
import org.apache.spark.HashPartitioner
|
||||||
import org.apache.spark.util.ClosureCleaner
|
import org.apache.spark.util.ClosureCleaner
|
||||||
|
|
||||||
import org.apache.spark.graph._
|
import org.apache.spark.graph._
|
||||||
import org.apache.spark.graph.impl.GraphImpl._
|
import org.apache.spark.graph.impl.GraphImpl._
|
||||||
import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._
|
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
import org.apache.spark.storage.StorageLevel
|
||||||
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
|
import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,8 +72,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
|
|
||||||
def this() = this(null, null, null, null)
|
def this() = this(null, null, null, null)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
|
* (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
|
||||||
* vertex data after it is replicated. Within each partition, it holds a map
|
* vertex data after it is replicated. Within each partition, it holds a map
|
||||||
|
@ -86,29 +84,28 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
@transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] =
|
@transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] =
|
||||||
createVTableReplicated(vTable, vid2pid, localVidMap)
|
createVTableReplicated(vTable, vid2pid, localVidMap)
|
||||||
|
|
||||||
|
|
||||||
/** Return a RDD of vertices. */
|
/** Return a RDD of vertices. */
|
||||||
@transient override val vertices = vTable
|
@transient override val vertices = vTable
|
||||||
|
|
||||||
|
|
||||||
/** Return a RDD of edges. */
|
/** Return a RDD of edges. */
|
||||||
@transient override val edges: RDD[Edge[ED]] = {
|
@transient override val edges: RDD[Edge[ED]] = {
|
||||||
eTable.mapPartitions( iter => iter.next()._2.iterator , true )
|
eTable.mapPartitions( iter => iter.next()._2.iterator , true )
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/** Return a RDD that brings edges with its source and destination vertices together. */
|
/** Return a RDD that brings edges with its source and destination vertices together. */
|
||||||
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
|
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
|
||||||
makeTriplets(localVidMap, vTableReplicatedValues, eTable)
|
makeTriplets(localVidMap, vTableReplicatedValues, eTable)
|
||||||
|
|
||||||
|
/**
 * Persist the tables backing this graph at the given storage level and return this graph.
 *
 * `persist(newLevel)` is applied to eTable, vid2pid, vTable and localVidMap, in that order.
 * vTableReplicatedValues is not persisted here.
 *
 * NOTE(review): Spark's RDD.persist is documented to accept a storage level only when none
 * has been assigned yet -- confirm these tables have not already been cached at another level.
 */
override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
  Seq[RDD[_]](eTable, vid2pid, vTable, localVidMap).foreach { table =>
    table.persist(newLevel)
  }
  this
}

/** Cache this graph; equivalent to `persist(StorageLevel.MEMORY_ONLY)`. */
override def cache(): Graph[VD, ED] = {
  persist(StorageLevel.MEMORY_ONLY)
}
|
||||||
|
|
||||||
override def statistics: Map[String, Any] = {
|
override def statistics: Map[String, Any] = {
|
||||||
val numVertices = this.numVertices
|
val numVertices = this.numVertices
|
||||||
|
@ -125,7 +122,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
"Min Load" -> minLoad, "Max Load" -> maxLoad)
|
"Min Load" -> minLoad, "Max Load" -> maxLoad)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Display the lineage information for this graph.
|
* Display the lineage information for this graph.
|
||||||
*/
|
*/
|
||||||
|
@ -183,14 +179,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
println(visited)
|
println(visited)
|
||||||
} // end of print lineage
|
} // end of print lineage
|
||||||
|
|
||||||
|
|
||||||
override def reverse: Graph[VD, ED] = {
|
override def reverse: Graph[VD, ED] = {
|
||||||
val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) },
|
val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) },
|
||||||
preservesPartitioning = true)
|
preservesPartitioning = true)
|
||||||
new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
|
new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
|
override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
|
||||||
val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
|
val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
|
||||||
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
|
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
|
||||||
|
@ -202,11 +196,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
|
override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
|
||||||
GraphImpl.mapTriplets(this, f)
|
GraphImpl.mapTriplets(this, f)
|
||||||
|
|
||||||
|
|
||||||
override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
|
override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
|
||||||
vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = {
|
vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = {
|
||||||
|
|
||||||
|
@ -246,7 +238,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
|
new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override def groupEdgeTriplets[ED2: ClassManifest](
|
override def groupEdgeTriplets[ED2: ClassManifest](
|
||||||
f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
|
f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
|
||||||
val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
|
val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
|
||||||
|
@ -271,7 +262,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
|
override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
|
||||||
Graph[VD,ED2] = {
|
Graph[VD,ED2] = {
|
||||||
|
|
||||||
|
@ -289,8 +279,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
new GraphImpl(vTable, vid2pid, localVidMap, newETable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Lower level transformation methods
|
// Lower level transformation methods
|
||||||
//////////////////////////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -301,7 +289,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
: VertexSetRDD[A] =
|
: VertexSetRDD[A] =
|
||||||
GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
|
GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
|
||||||
|
|
||||||
|
|
||||||
override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
|
override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
|
||||||
(updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2)
|
(updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2)
|
||||||
: Graph[VD2, ED] = {
|
: Graph[VD2, ED] = {
|
||||||
|
@ -309,15 +296,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||||
val newVTable = vTable.leftJoin(updates)(updateF)
|
val newVTable = vTable.leftJoin(updates)(updateF)
|
||||||
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
|
new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
} // end of class GraphImpl
|
} // end of class GraphImpl
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
object GraphImpl {
|
object GraphImpl {
|
||||||
|
|
||||||
def apply[VD: ClassManifest, ED: ClassManifest](
|
def apply[VD: ClassManifest, ED: ClassManifest](
|
||||||
|
@ -327,7 +308,6 @@ object GraphImpl {
|
||||||
apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
|
apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def apply[VD: ClassManifest, ED: ClassManifest](
|
def apply[VD: ClassManifest, ED: ClassManifest](
|
||||||
vertices: RDD[(Vid, VD)],
|
vertices: RDD[(Vid, VD)],
|
||||||
edges: RDD[Edge[ED]],
|
edges: RDD[Edge[ED]],
|
||||||
|
@ -353,7 +333,6 @@ object GraphImpl {
|
||||||
new GraphImpl(vtable, vid2pid, localVidMap, etable)
|
new GraphImpl(vtable, vid2pid, localVidMap, etable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create the edge table RDD, which is much more efficient for Java heap storage than the
|
* Create the edge table RDD, which is much more efficient for Java heap storage than the
|
||||||
* normal edges data structure (RDD[(Vid, Vid, ED)]).
|
* normal edges data structure (RDD[(Vid, Vid, ED)]).
|
||||||
|
@ -375,7 +354,7 @@ object GraphImpl {
|
||||||
//val part: Pid = canonicalEdgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
|
//val part: Pid = canonicalEdgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
|
||||||
|
|
||||||
// Should we be using 3-tuple or an optimized class
|
// Should we be using 3-tuple or an optimized class
|
||||||
MessageToPartition(part, (e.srcId, e.dstId, e.attr))
|
new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
|
||||||
}
|
}
|
||||||
.partitionBy(new HashPartitioner(numPartitions))
|
.partitionBy(new HashPartitioner(numPartitions))
|
||||||
.mapPartitionsWithIndex( (pid, iter) => {
|
.mapPartitionsWithIndex( (pid, iter) => {
|
||||||
|
@ -389,7 +368,6 @@ object GraphImpl {
|
||||||
}, preservesPartitioning = true).cache()
|
}, preservesPartitioning = true).cache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected def createVid2Pid[ED: ClassManifest](
|
protected def createVid2Pid[ED: ClassManifest](
|
||||||
eTable: RDD[(Pid, EdgePartition[ED])],
|
eTable: RDD[(Pid, EdgePartition[ED])],
|
||||||
vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
|
vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
|
||||||
|
@ -398,7 +376,7 @@ object GraphImpl {
|
||||||
val vSet = new VertexSet
|
val vSet = new VertexSet
|
||||||
edgePartition.foreach(e => {vSet.add(e.srcId); vSet.add(e.dstId)})
|
edgePartition.foreach(e => {vSet.add(e.srcId); vSet.add(e.dstId)})
|
||||||
vSet.iterator.map { vid => (vid.toLong, pid) }
|
vSet.iterator.map { vid => (vid.toLong, pid) }
|
||||||
}
|
}.partitionBy(vTableIndex.rdd.partitioner.get)
|
||||||
VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
|
VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
|
||||||
(p: Pid) => ArrayBuffer(p),
|
(p: Pid) => ArrayBuffer(p),
|
||||||
(ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab},
|
(ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab},
|
||||||
|
@ -406,7 +384,6 @@ object GraphImpl {
|
||||||
.mapValues(a => a.toArray).cache()
|
.mapValues(a => a.toArray).cache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
|
protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
|
||||||
RDD[(Pid, VertexIdToIndexMap)] = {
|
RDD[(Pid, VertexIdToIndexMap)] = {
|
||||||
eTable.mapPartitions( _.map{ case (pid, epart) =>
|
eTable.mapPartitions( _.map{ case (pid, epart) =>
|
||||||
|
@ -419,7 +396,6 @@ object GraphImpl {
|
||||||
}, preservesPartitioning = true).cache()
|
}, preservesPartitioning = true).cache()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected def createVTableReplicated[VD: ClassManifest](
|
protected def createVTableReplicated[VD: ClassManifest](
|
||||||
vTable: VertexSetRDD[VD],
|
vTable: VertexSetRDD[VD],
|
||||||
vid2pid: VertexSetRDD[Array[Pid]],
|
vid2pid: VertexSetRDD[Array[Pid]],
|
||||||
|
@ -428,7 +404,10 @@ object GraphImpl {
|
||||||
// Join vid2pid and vTable, generate a shuffle dependency on the joined
|
// Join vid2pid and vTable, generate a shuffle dependency on the joined
|
||||||
// result, and get the shuffle id so we can use it on the slave.
|
// result, and get the shuffle id so we can use it on the slave.
|
||||||
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
|
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
|
||||||
pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
|
// TODO(rxin): reuse VertexBroadcastMessage
|
||||||
|
pids.iterator.map { pid =>
|
||||||
|
new VertexBroadcastMsg[VD](pid, vid, vdata)
|
||||||
|
}
|
||||||
}.partitionBy(replicationMap.partitioner.get).cache()
|
}.partitionBy(replicationMap.partitioner.get).cache()
|
||||||
|
|
||||||
replicationMap.zipPartitions(msgsByPartition){
|
replicationMap.zipPartitions(msgsByPartition){
|
||||||
|
@ -438,8 +417,8 @@ object GraphImpl {
|
||||||
// Populate the vertex array using the vidToIndex map
|
// Populate the vertex array using the vidToIndex map
|
||||||
val vertexArray = new Array[VD](vidToIndex.capacity)
|
val vertexArray = new Array[VD](vidToIndex.capacity)
|
||||||
for (msg <- msgsIter) {
|
for (msg <- msgsIter) {
|
||||||
val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK
|
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
|
||||||
vertexArray(ind) = msg.data._2
|
vertexArray(ind) = msg.data
|
||||||
}
|
}
|
||||||
Iterator((pid, vertexArray))
|
Iterator((pid, vertexArray))
|
||||||
}.cache()
|
}.cache()
|
||||||
|
@ -447,7 +426,6 @@ object GraphImpl {
|
||||||
// @todo assert edge table has partitioner
|
// @todo assert edge table has partitioner
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def makeTriplets[VD: ClassManifest, ED: ClassManifest](
|
def makeTriplets[VD: ClassManifest, ED: ClassManifest](
|
||||||
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
|
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
|
||||||
vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
|
vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
|
||||||
|
@ -461,7 +439,6 @@ object GraphImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
|
def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
|
||||||
g: GraphImpl[VD, ED],
|
g: GraphImpl[VD, ED],
|
||||||
f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
|
f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
|
||||||
|
@ -483,7 +460,6 @@ object GraphImpl {
|
||||||
new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
|
new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
|
def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
|
||||||
g: GraphImpl[VD, ED],
|
g: GraphImpl[VD, ED],
|
||||||
mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
|
mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
|
||||||
|
@ -495,33 +471,35 @@ object GraphImpl {
|
||||||
// Map and preaggregate
|
// Map and preaggregate
|
||||||
val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
|
val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
|
||||||
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
|
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
|
||||||
val (pid, edgePartition) = edgePartitionIter.next()
|
val (_, edgePartition) = edgePartitionIter.next()
|
||||||
val (_, vidToIndex) = vidToIndexIter.next()
|
val (_, vidToIndex) = vidToIndexIter.next()
|
||||||
val (_, vertexArray) = vertexArrayIter.next()
|
val (_, vertexArray) = vertexArrayIter.next()
|
||||||
assert(!edgePartitionIter.hasNext)
|
assert(!edgePartitionIter.hasNext)
|
||||||
assert(!vidToIndexIter.hasNext)
|
assert(!vidToIndexIter.hasNext)
|
||||||
assert(!vertexArrayIter.hasNext)
|
assert(!vertexArrayIter.hasNext)
|
||||||
assert(vidToIndex.capacity == vertexArray.size)
|
assert(vidToIndex.capacity == vertexArray.size)
|
||||||
|
// Reuse the vidToIndex map to run aggregation.
|
||||||
val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
|
val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
|
||||||
// We can reuse the vidToIndex map for aggregation here as well.
|
// TODO(jegonzal): This doesn't allow users to send messages to arbitrary vertices.
|
||||||
/** @todo Since this has the downside of not allowing "messages" to arbitrary
|
|
||||||
* vertices we should consider just using a fresh map.
|
|
||||||
*/
|
|
||||||
val msgArray = new Array[A](vertexArray.size)
|
val msgArray = new Array[A](vertexArray.size)
|
||||||
val msgBS = new BitSet(vertexArray.size)
|
val msgBS = new BitSet(vertexArray.size)
|
||||||
// Iterate over the partition
|
// Iterate over the partition
|
||||||
val et = new EdgeTriplet[VD, ED]
|
val et = new EdgeTriplet[VD, ED]
|
||||||
edgePartition.foreach{e =>
|
|
||||||
|
edgePartition.foreach { e =>
|
||||||
et.set(e)
|
et.set(e)
|
||||||
et.srcAttr = vmap(e.srcId)
|
et.srcAttr = vmap(e.srcId)
|
||||||
et.dstAttr = vmap(e.dstId)
|
et.dstAttr = vmap(e.dstId)
|
||||||
mapFunc(et).foreach{ case (vid, msg) =>
|
// TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
|
||||||
|
// Also given we are only allowing zero, one, or two messages, we can completely unroll
|
||||||
|
// the for loop.
|
||||||
|
mapFunc(et).foreach { case (vid, msg) =>
|
||||||
// verify that the vid is valid
|
// verify that the vid is valid
|
||||||
assert(vid == et.srcId || vid == et.dstId)
|
assert(vid == et.srcId || vid == et.dstId)
|
||||||
// Get the index of the key
|
// Get the index of the key
|
||||||
val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
|
val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
|
||||||
// Populate the aggregator map
|
// Populate the aggregator map
|
||||||
if(msgBS.get(ind)) {
|
if (msgBS.get(ind)) {
|
||||||
msgArray(ind) = reduceFunc(msgArray(ind), msg)
|
msgArray(ind) = reduceFunc(msgArray(ind), msg)
|
||||||
} else {
|
} else {
|
||||||
msgArray(ind) = msg
|
msgArray(ind) = msg
|
||||||
|
@ -530,20 +508,19 @@ object GraphImpl {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// construct an iterator of tuples Iterator[(Vid, A)]
|
// construct an iterator of tuples Iterator[(Vid, A)]
|
||||||
msgBS.iterator.map( ind => (vidToIndex.getValue(ind), msgArray(ind)) )
|
msgBS.iterator.map { ind =>
|
||||||
|
new AggregationMsg[A](vidToIndex.getValue(ind), msgArray(ind))
|
||||||
|
}
|
||||||
}.partitionBy(g.vTable.index.rdd.partitioner.get)
|
}.partitionBy(g.vTable.index.rdd.partitioner.get)
|
||||||
// do the final reduction reusing the index map
|
// do the final reduction reusing the index map
|
||||||
VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
|
VertexSetRDD.aggregate(preAgg, g.vTable.index, reduceFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
 * Choose a partition for an edge from its source vertex id only.
 *
 * The absolute value of `src` is multiplied by a large prime, narrowed to an Int and
 * reduced modulo `numParts`. `dst` is not used.
 *
 * @param src      source vertex id of the edge
 * @param dst      destination vertex id of the edge (ignored)
 * @param numParts number of partitions; assumed to be positive
 * @return a partition id in the range [0, numParts)
 */
protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
  val mixingPrime: Vid = 1125899906842597L
  // The product can overflow and `.toInt` keeps only the low 32 bits, so the value fed
  // to `%` may be negative -- and `%` keeps the sign of its left operand.
  val remainder = (math.abs(src) * mixingPrime).toInt % numParts
  // Shift a negative remainder back into [0, numParts); non-negative results are unchanged.
  if (remainder < 0) remainder + numParts else remainder
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function implements a classic 2D-Partitioning of a sparse matrix.
|
* This function implements a classic 2D-Partitioning of a sparse matrix.
|
||||||
* Suppose we have a graph with 11 vertices that we want to partition
|
* Suppose we have a graph with 11 vertices that we want to partition
|
||||||
|
@ -596,7 +573,6 @@ object GraphImpl {
|
||||||
(col * ceilSqrtNumParts + row) % numParts
|
(col * ceilSqrtNumParts + row) % numParts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Assign edges to an arbitrary machine corresponding to a
|
* Assign edges to an arbitrary machine corresponding to a
|
||||||
* random vertex cut.
|
* random vertex cut.
|
||||||
|
@ -605,7 +581,6 @@ object GraphImpl {
|
||||||
math.abs((src, dst).hashCode()) % numParts
|
math.abs((src, dst).hashCode()) % numParts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @todo This will only partition edges to the upper diagonal
|
* @todo This will only partition edges to the upper diagonal
|
||||||
* of the 2D processor space.
|
* of the 2D processor space.
|
||||||
|
@ -622,4 +597,3 @@ object GraphImpl {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end of object GraphImpl
|
} // end of object GraphImpl
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,35 @@
|
||||||
package org.apache.spark.graph.impl
|
package org.apache.spark.graph.impl
|
||||||
|
|
||||||
import org.apache.spark.Partitioner
|
import org.apache.spark.Partitioner
|
||||||
import org.apache.spark.graph.Pid
|
import org.apache.spark.graph.{Pid, Vid}
|
||||||
import org.apache.spark.rdd.{ShuffledRDD, RDD}
|
import org.apache.spark.rdd.{ShuffledRDD, RDD}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A vertex value addressed to a target partition.
 *
 * As a `Product2` the first component is the target partition and the second component is
 * the `(vid, data)` pair, which is built anew on every `_2` call. All fields are mutable;
 * `partition` is marked `@transient`.
 *
 * @param partition id of the target partition
 * @param vid       id of the vertex the value belongs to
 * @param data      the value carried for that vertex
 */
class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
    @transient var partition: Pid,
    var vid: Vid,
    var data: T)
  extends Product2[Pid, (Vid, T)] {

  override def _1: Pid = partition

  override def _2: (Vid, T) = (vid, data)

  override def canEqual(that: Any): Boolean = that match {
    case _: VertexBroadcastMsg[_] => true
    case _ => false
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A value addressed to a vertex, exposed as the `Product2` pair `(vid, data)`.
 *
 * Both fields are mutable.
 *
 * @param vid  id of the vertex the value is addressed to
 * @param data the value carried for that vertex
 */
class AggregationMsg[@specialized(Int, Long, Double, Boolean) T](
    var vid: Vid,
    var data: T)
  extends Product2[Vid, T] {

  override def _1: Vid = vid

  override def _2: T = data

  override def canEqual(that: Any): Boolean = that match {
    case _: AggregationMsg[_] => true
    case _ => false
  }
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A message used to send a specific value to a partition.
|
* A message used to send a specific value to a partition.
|
||||||
* @param partition index of the target partition.
|
* @param partition index of the target partition.
|
||||||
|
@ -22,15 +47,42 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef
|
||||||
override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
|
override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Companion object for MessageToPartition.
|
/**
 * Adds a `partitionBy` operation to RDDs of VertexBroadcastMsg.
 */
class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcastMsg[T]]) {

  /**
   * Shuffle the messages with the given partitioner.
   *
   * When T is Int, Long or Double a matching serializer class is set on the shuffle;
   * for any other T the shuffled RDD is returned without calling `setSerializer`.
   */
  def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
    val shuffled = new ShuffledRDD[Pid, (Vid, T), VertexBroadcastMsg[T]](self, partitioner)

    // Pick the serializer class name for the primitive payload types, if any.
    val payloadManifest = classManifest[T]
    val serializerName: Option[String] =
      if (payloadManifest == ClassManifest.Int) {
        Some(classOf[IntVertexBroadcastMsgSerializer].getName)
      } else if (payloadManifest == ClassManifest.Long) {
        Some(classOf[LongVertexBroadcastMsgSerializer].getName)
      } else if (payloadManifest == ClassManifest.Double) {
        Some(classOf[DoubleVertexBroadcastMsgSerializer].getName)
      } else {
        None
      }

    serializerName.foreach(name => shuffled.setSerializer(name))
    shuffled
  }
}
|
||||||
|
|
||||||
|
|
||||||
class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
|
/**
 * Extra operations on RDDs of [[AggregationMsg]].
 */
class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[T]]) {

  /**
   * Shuffle the messages with the given partitioner.
   *
   * @param partitioner partitioner handed to the underlying `ShuffledRDD`
   * @return the shuffled RDD of messages
   */
  def partitionBy(partitioner: Partitioner): RDD[AggregationMsg[T]] = {
    val shuffled = new ShuffledRDD[Vid, T, AggregationMsg[T]](self, partitioner)

    // Int, Long and Double payloads have a dedicated serializer; for any other
    // payload type no serializer is set here.
    val payloadType = classManifest[T]
    val customSerializer: Option[String] =
      if (payloadType == ClassManifest.Int) {
        Some(classOf[IntAggMsgSerializer].getName)
      } else if (payloadType == ClassManifest.Long) {
        Some(classOf[LongAggMsgSerializer].getName)
      } else if (payloadType == ClassManifest.Double) {
        Some(classOf[DoubleAggMsgSerializer].getName)
      } else {
        None
      }
    customSerializer.foreach { name => shuffled.setSerializer(name) }

    shuffled
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
class MsgRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a copy of the RDD partitioned using the specified partitioner.
|
* Return a copy of the RDD partitioned using the specified partitioner.
|
||||||
|
@ -42,8 +94,16 @@ class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
object MessageToPartitionRDDFunctions {
|
/**
 * Implicit conversions that attach the message-specific RDD operations to RDDs
 * of each message type.
 */
object MsgRDDFunctions {

  /** Wrap an RDD of `MessageToPartition` in `MsgRDDFunctions`. */
  implicit def rdd2PartitionRDDFunctions[T: ClassManifest](
      rdd: RDD[MessageToPartition[T]]): MsgRDDFunctions[T] =
    new MsgRDDFunctions(rdd)

  /** Wrap an RDD of `VertexBroadcastMsg` in `VertexBroadcastMsgRDDFunctions`. */
  implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](
      rdd: RDD[VertexBroadcastMsg[T]]): VertexBroadcastMsgRDDFunctions[T] =
    new VertexBroadcastMsgRDDFunctions(rdd)

  /** Wrap an RDD of `AggregationMsg` in `AggregationMessageRDDFunctions`. */
  implicit def rdd2aggMessageRDDFunctions[T: ClassManifest](
      rdd: RDD[AggregationMsg[T]]): AggregationMessageRDDFunctions[T] =
    new AggregationMessageRDDFunctions(rdd)
}
|
||||||
|
|
|
@ -0,0 +1,224 @@
|
||||||
|
package org.apache.spark.graph.impl
|
||||||
|
|
||||||
|
import java.io.{EOFException, InputStream, OutputStream}
|
||||||
|
import java.nio.ByteBuffer
|
||||||
|
|
||||||
|
import org.apache.spark.serializer._
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for VertexBroadcastMsg[Int].
 *
 * A message is written as its vertex id (a long) followed by its Int data.
 * The partition field is not written; messages read back carry partition 0.
 */
class IntVertexBroadcastMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[VertexBroadcastMsg[Int]]
        writeLong(outgoing.vid)
        writeInt(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readInt()
        new VertexBroadcastMsg[Int](0, vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for VertexBroadcastMsg[Long].
 *
 * A message is written as its vertex id (a long) followed by its Long data.
 * The partition field is not written; messages read back carry partition 0.
 */
class LongVertexBroadcastMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[VertexBroadcastMsg[Long]]
        writeLong(outgoing.vid)
        writeLong(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readLong()
        new VertexBroadcastMsg[Long](0, vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for VertexBroadcastMsg[Double].
 *
 * A message is written as its vertex id (a long) followed by its Double data.
 * The partition field is not written; messages read back carry partition 0.
 */
class DoubleVertexBroadcastMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[VertexBroadcastMsg[Double]]
        writeLong(outgoing.vid)
        writeDouble(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readDouble()
        new VertexBroadcastMsg[Double](0, vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for AggregationMsg[Int].
 *
 * A message is written as its vertex id (a long) followed by its Int data.
 */
class IntAggMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[AggregationMsg[Int]]
        writeLong(outgoing.vid)
        writeInt(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readInt()
        new AggregationMsg[Int](vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for AggregationMsg[Long].
 *
 * A message is written as its vertex id (a long) followed by its Long data.
 */
class LongAggMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[AggregationMsg[Long]]
        writeLong(outgoing.vid)
        writeLong(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readLong()
        new AggregationMsg[Long](vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A special shuffle serializer for AggregationMsg[Double].
 *
 * A message is written as its vertex id (a long) followed by its Double data.
 */
class DoubleAggMsgSerializer extends Serializer {
  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {

    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
      def writeObject[T](t: T): SerializationStream = {
        val outgoing = t.asInstanceOf[AggregationMsg[Double]]
        writeLong(outgoing.vid)
        writeDouble(outgoing.data)
        this
      }
    }

    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
      override def readObject[T](): T = {
        // Fields are read in the order they were written: vid first, then data.
        val vertexId = readLong()
        val payload = readDouble()
        new AggregationMsg[Double](vertexId, payload).asInstanceOf[T]
      }
    }
  }
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Helper classes to shorten the implementation of those special serializers.
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/**
 * Base class for the hand-written shuffle serialization streams. It provides
 * big-endian primitive writers on top of the wrapped `OutputStream`.
 *
 * @param s the stream that receives the encoded bytes
 */
sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
  // Each concrete serializer supplies its own message encoding.
  def writeObject[T](t: T): SerializationStream

  /** Write `v` as 4 bytes, most significant byte first. */
  def writeInt(v: Int): Unit = {
    // OutputStream.write(int) keeps only the low 8 bits of its argument.
    var shift = 24
    while (shift >= 0) {
      s.write(v >> shift)
      shift -= 8
    }
  }

  /** Write `v` as 8 bytes, most significant byte first. */
  def writeLong(v: Long): Unit = {
    var shift = 56
    while (shift >= 0) {
      s.write((v >>> shift).toInt)
      shift -= 8
    }
  }

  /** Write `v` as the 8 bytes of its `doubleToLongBits` representation. */
  def writeDouble(v: Double): Unit = {
    val bits = java.lang.Double.doubleToLongBits(v)
    writeLong(bits)
  }

  override def flush(): Unit = s.flush()

  override def close(): Unit = s.close()
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Base class for the hand-written shuffle deserialization streams. It provides
 * big-endian primitive readers on top of the wrapped `InputStream`.
 *
 * Every reader throws `EOFException` if the stream ends before the whole value
 * has been read. That covers both a clean end of stream (no bytes left at a
 * value boundary) and a truncated value.
 *
 * @param s the stream the encoded bytes are read from
 */
sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
  // Each concrete serializer supplies its own message decoding.
  def readObject[T](): T

  /**
   * Read 4 bytes, most significant first, as an Int.
   *
   * @throws EOFException if fewer than 4 bytes remain
   */
  def readInt(): Int = {
    val b0 = nextByte()
    val b1 = nextByte()
    val b2 = nextByte()
    val b3 = nextByte()
    (b0 << 24) | (b1 << 16) | (b2 << 8) | b3
  }

  /**
   * Read 8 bytes, most significant first, as a Long.
   *
   * @throws EOFException if fewer than 8 bytes remain
   */
  def readLong(): Long = {
    var result = 0L
    var remaining = 8
    while (remaining > 0) {
      // nextByte() is in 0..255, so widening it to Long cannot sign-extend.
      result = (result << 8) | nextByte()
      remaining -= 1
    }
    result
  }

  /**
   * Read 8 bytes and reinterpret them as a Double via `longBitsToDouble`.
   *
   * @throws EOFException if fewer than 8 bytes remain
   */
  def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())

  override def close(): Unit = s.close()

  /**
   * Read one byte as a value in 0..255. InputStream.read() signals end of
   * stream with -1; masking that with 0xFF would silently turn it into a data
   * byte, so it is checked on every byte rather than only the first.
   */
  private def nextByte(): Int = {
    val b = s.read()
    if (b < 0) throw new EOFException
    b
  }
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * A `SerializerInstance` that supports only the stream-based API. The
 * `ByteBuffer`-based methods all throw `UnsupportedOperationException`.
 */
sealed trait ShuffleSerializerInstance extends SerializerInstance {

  // Implementations provide the two stream factories.
  override def serializeStream(s: OutputStream): SerializationStream

  override def deserializeStream(s: InputStream): DeserializationStream

  // The single-object ByteBuffer API is not supported.
  override def serialize[T](t: T): ByteBuffer = {
    throw new UnsupportedOperationException
  }

  override def deserialize[T](bytes: ByteBuffer): T = {
    throw new UnsupportedOperationException
  }

  override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
    throw new UnsupportedOperationException
  }
}
|
|
@ -8,8 +8,7 @@ package object graph {
|
||||||
type Vid = Long
|
type Vid = Long
|
||||||
type Pid = Int
|
type Pid = Int
|
||||||
|
|
||||||
type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T]
|
type VertexSet = OpenHashSet[Vid]
|
||||||
type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet
|
|
||||||
type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
|
type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
|
||||||
|
|
||||||
// type VertexIdToIndexMap = it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap
|
// type VertexIdToIndexMap = it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
package org.apache.spark.graph
|
||||||
|
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
import org.apache.spark.SparkContext
|
||||||
|
import org.apache.spark.graph.LocalSparkContext._
|
||||||
|
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
|
||||||
|
import org.apache.spark.graph.impl._
|
||||||
|
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||||
|
import org.apache.spark._
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Tests for the hand-written shuffle serializers: a stream round trip for each
 * message and payload type, plus two runs that push messages through a shuffle.
 */
class SerializerSuite extends FunSuite with LocalSparkContext {

  System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")

  /**
   * Write `outMsg` twice with one serializer, read both copies back with
   * another, hand them to `verify`, then check that a third read hits
   * end of stream.
   *
   * @param mkSerializer by-name, so the writer and the reader each get a fresh serializer
   * @param outMsg the message to write
   * @param verify assertions over the two messages read back
   */
  private def roundTripTwice[M](
      mkSerializer: => org.apache.spark.serializer.Serializer,
      outMsg: M)(verify: (M, M) => Unit) {
    val bout = new ByteArrayOutputStream
    val outStrm = mkSerializer.newInstance().serializeStream(bout)
    outStrm.writeObject(outMsg)
    outStrm.writeObject(outMsg)
    bout.flush
    val bin = new ByteArrayInputStream(bout.toByteArray)
    val inStrm = mkSerializer.newInstance().deserializeStream(bin)
    val inMsg1: M = inStrm.readObject()
    val inMsg2: M = inStrm.readObject()
    verify(inMsg1, inMsg2)

    intercept[EOFException] {
      inStrm.readObject()
    }
  }

  test("TestVertexBroadcastMessageInt") {
    val outMsg = new VertexBroadcastMsg[Int](3,4,5)
    roundTripTwice(new IntVertexBroadcastMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  test("TestVertexBroadcastMessageLong") {
    val outMsg = new VertexBroadcastMsg[Long](3,4,5)
    roundTripTwice(new LongVertexBroadcastMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  test("TestVertexBroadcastMessageDouble") {
    val outMsg = new VertexBroadcastMsg[Double](3,4,5.0)
    roundTripTwice(new DoubleVertexBroadcastMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  test("TestAggregationMessageInt") {
    val outMsg = new AggregationMsg[Int](4,5)
    roundTripTwice(new IntAggMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  test("TestAggregationMessageLong") {
    val outMsg = new AggregationMsg[Long](4,5)
    roundTripTwice(new LongAggMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  test("TestAggregationMessageDouble") {
    val outMsg = new AggregationMsg[Double](4,5.0)
    roundTripTwice(new DoubleAggMsgSerializer(), outMsg) { (inMsg1, inMsg2) =>
      assert(outMsg.vid === inMsg1.vid)
      assert(outMsg.vid === inMsg2.vid)
      assert(outMsg.data === inMsg1.data)
      assert(outMsg.data === inMsg2.data)
    }
  }

  // The two shuffle tests only check that the job runs to completion.
  test("TestShuffleVertexBroadcastMsg") {
    withSpark(new SparkContext("local[2]", "test")) { sc =>
      val ids = sc.parallelize(0 until 100, 10)
      val bmsgs = ids.map(pid => new VertexBroadcastMsg[Int](pid, pid, pid))
      bmsgs.partitionBy(new HashPartitioner(3)).collect()
    }
  }

  test("TestShuffleAggregationMsg") {
    withSpark(new SparkContext("local[2]", "test")) { sc =>
      val ids = sc.parallelize(0 until 100, 10)
      val bmsgs = ids.map { pid =>
        new AggregationMsg[Int](pid, pid)
      }
      bmsgs.partitionBy(new HashPartitioner(3)).collect()
    }
  }

}
|
Loading…
Reference in a new issue