Factor out VTableReplicatedValues

Ankur Dave 2013-11-09 04:20:30 -08:00
parent cdbd19bbee
commit bf4e45e685
3 changed files with 112 additions and 74 deletions


@@ -63,6 +63,13 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
/**
* A Graph RDD that supports computation on graphs.
*
* @param localVidMap Stores the location of vertex attributes after they are
* replicated. Within each partition, localVidMap holds a map from vertex ID to
* the index where that vertex's attribute is stored. This index refers to the
* arrays in the same partition in the variants of
* [[VTableReplicatedValues]]. Therefore, localVidMap can be reused across
* changes to the vertex attributes.
*/
class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
@transient val vTable: VertexSetRDD[VD],
@@ -73,27 +80,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
def this() = this(null, null, null, null)
/**
* (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
* vertex data after it is replicated. Within each partition, it holds a map
* from vertex ID to the index where that vertex's attribute is stored. This
* index refers to an array in the same partition in vTableReplicatedValues.
*
* (vTableReplicatedValues: VertexSetRDD[Pid, Array[VD]]) holds the vertex data
* and is arranged as described above.
*/
@transient val vTableReplicatedValuesBothAttrs: RDD[(Pid, Array[VD])] =
createVTableReplicated(vTable, vid2pid.bothAttrs, localVidMap)
@transient val vTableReplicatedValuesSrcAttrOnly: RDD[(Pid, Array[VD])] =
createVTableReplicated(vTable, vid2pid.srcAttrOnly, localVidMap)
@transient val vTableReplicatedValuesDstAttrOnly: RDD[(Pid, Array[VD])] =
createVTableReplicated(vTable, vid2pid.dstAttrOnly, localVidMap)
// TODO(ankurdave): create this more efficiently
@transient val vTableReplicatedValuesNoAttrs: RDD[(Pid, Array[VD])] =
createVTableReplicated(vTable, vid2pid.noAttrs, localVidMap)
@transient val vTableReplicatedValues: VTableReplicatedValues[VD] =
new VTableReplicatedValues(vTable, vid2pid, localVidMap)
/** Return an RDD of vertices. */
@transient override val vertices = vTable
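To make the indexing scheme described in the new scaladoc concrete: within a partition, localVidMap supplies an open-hash-set position for each vertex ID, and that position is the slot in the corresponding array of whichever VTableReplicatedValues variant is in use. A minimal read-side sketch follows; the helper lookupAttr is hypothetical, not part of this commit, and assumes the same OpenHashSet position masking used by the replication code below.

import org.apache.spark.graph._
import org.apache.spark.util.collection.OpenHashSet

// Hypothetical helper: resolve one vertex's replicated attribute within a partition.
// vidToIndex is the partition's localVidMap entry; vertexArray is the matching entry
// from a VTableReplicatedValues variant (e.g. bothAttrs) for the same partition.
def lookupAttr[VD](vidToIndex: VertexIdToIndexMap, vertexArray: Array[VD], vid: Vid): VD = {
  // getPos locates vid in the open hash set; masking off the flag bits yields the
  // array index, mirroring how createVTableReplicated fills the array.
  val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
  vertexArray(ind)
}

Because localVidMap only records positions, not attribute values, the same map can be zipped against fresh attribute arrays whenever the vertex data changes, which is exactly why the refactored field below holds the arrays separately.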
@@ -105,7 +93,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
/** Return an RDD that brings edges together with their source and destination vertices. */
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
makeTriplets(localVidMap, vTableReplicatedValuesBothAttrs, eTable)
makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
eTable.persist(newLevel)
@@ -188,9 +176,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
traverseLineage(localVidMap, " ", visited)
visited += (localVidMap.id -> "localVidMap")
println("\n\nvTableReplicatedValuesBothAttrs -----------------")
traverseLineage(vTableReplicatedValuesBothAttrs, " ", visited)
visited += (vTableReplicatedValuesBothAttrs.id -> "vTableReplicatedValuesBothAttrs")
println("\n\nvTableReplicatedValues.bothAttrs ----------------")
traverseLineage(vTableReplicatedValues.bothAttrs, " ", visited)
visited += (vTableReplicatedValues.bothAttrs.id -> "vTableReplicatedValues.bothAttrs")
println("\n\ntriplets ----------------------------------------")
traverseLineage(triplets, " ", visited)
@@ -386,8 +374,9 @@ object GraphImpl {
}, preservesPartitioning = true).cache()
}
protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
RDD[(Pid, VertexIdToIndexMap)] = {
private def createLocalVidMap(
eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED }
): RDD[(Pid, VertexIdToIndexMap)] = {
eTable.mapPartitions( _.map{ case (pid, epart) =>
val vidToIndex = new VertexIdToIndexMap
epart.foreach{ e =>
@@ -398,36 +387,6 @@ object GraphImpl {
}, preservesPartitioning = true).cache()
}
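The new createLocalVidMap signature replaces the [ED: ClassManifest] type parameter with an existential type: because the method only reads vertex IDs from each EdgePartition and never touches the edge attribute, RDD[(Pid, EdgePartition[ED])] forSome { type ED } lets any edge table be passed without threading a manifest through. A standalone sketch of the same Scala pattern, with made-up names purely for illustration:

// Box stands in for EdgePartition; the element type is never inspected.
class Box[A](val items: Seq[A])

// Accepts a Box of any element type because A is only quantified, never used.
def size(b: Box[A] forSome { type A }): Int = b.items.length

size(new Box(Seq(1, 2, 3)))    // 3
size(new Box(Seq("x", "y")))   // 2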
protected def createVTableReplicated[VD: ClassManifest](
vTable: VertexSetRDD[VD],
vid2pid: VertexSetRDD[Array[Pid]],
replicationMap: RDD[(Pid, VertexIdToIndexMap)]):
RDD[(Pid, Array[VD])] = {
// Join vid2pid and vTable, generate a shuffle dependency on the joined
// result, and get the shuffle id so we can use it on the slave.
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
// TODO(rxin): reuse VertexBroadcastMessage
pids.iterator.map { pid =>
new VertexBroadcastMsg[VD](pid, vid, vdata)
}
}.partitionBy(replicationMap.partitioner.get).cache()
replicationMap.zipPartitions(msgsByPartition){
(mapIter, msgsIter) =>
val (pid, vidToIndex) = mapIter.next()
assert(!mapIter.hasNext)
// Populate the vertex array using the vidToIndex map
val vertexArray = new Array[VD](vidToIndex.capacity)
for (msg <- msgsIter) {
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = msg.data
}
Iterator((pid, vertexArray))
}.cache()
// @todo assert edge table has partitioner
}
def makeTriplets[VD: ClassManifest, ED: ClassManifest](
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
vTableReplicatedValues: RDD[(Pid, Array[VD])],
@@ -444,7 +403,7 @@ object GraphImpl {
def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
g: GraphImpl[VD, ED],
f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValuesBothAttrs){
val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues.bothAttrs){
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
val (pid, edgePartition) = edgePartitionIter.next()
val (_, vidToIndex) = vidToIndexIter.next()
@@ -476,15 +435,12 @@ object GraphImpl {
BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr")
val mapFuncUsesDstAttr =
BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr")
val vTableReplicatedValues = (mapFuncUsesSrcAttr, mapFuncUsesDstAttr) match {
case (true, true) => g.vTableReplicatedValuesBothAttrs
case (true, false) => g.vTableReplicatedValuesSrcAttrOnly
case (false, true) => g.vTableReplicatedValuesDstAttrOnly
case (false, false) => g.vTableReplicatedValuesNoAttrs
}
// Map and preaggregate
val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){
val preAgg = g.eTable.zipPartitions(
g.localVidMap,
g.vTableReplicatedValues.get(mapFuncUsesSrcAttr, mapFuncUsesDstAttr)
){
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
val (_, edgePartition) = edgePartitionIter.next()
val (_, vidToIndex) = vidToIndexIter.next()
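The payoff of this dispatch is that attributes the user's map function never reads are never replicated to the edge partitions. A hedged illustration (user-level code, not part of this diff; the exact aggregation entry point surrounding this hunk is assumed): a map function that only emits a constant per edge calls neither srcAttr nor dstAttr, so both bytecode checks come back false and get(false, false) selects the noAttrs table.

import org.apache.spark.graph._

// Hypothetical map function for an in-degree count; it reads only edge endpoints.
val countFunc: EdgeTriplet[Double, Int] => Iterator[(Vid, Int)] =
  et => Iterator((et.dstId, 1))          // never calls et.srcAttr or et.dstAttr

// BytecodeUtils.invokedMethod(countFunc, classOf[EdgeTriplet[Double, Int]], "srcAttr")
// and the matching dstAttr check would both return false, so the pre-aggregation
// above zips eTable and localVidMap against vTableReplicatedValues.get(false, false),
// i.e. the noAttrs variant, and no vertex data is shuffled for this computation.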


@@ -0,0 +1,72 @@
package org.apache.spark.graph.impl
import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.graph._
import org.apache.spark.graph.impl.MsgRDDFunctions._
/**
* Stores the vertex attribute values after they are replicated. See
* the description of localVidMap in [[GraphImpl]].
*/
class VTableReplicatedValues[VD: ClassManifest](
vTable: VertexSetRDD[VD],
vid2pid: Vid2Pid,
localVidMap: RDD[(Pid, VertexIdToIndexMap)]) {
val bothAttrs: RDD[(Pid, Array[VD])] =
VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, true)
val srcAttrOnly: RDD[(Pid, Array[VD])] =
VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, false)
val dstAttrOnly: RDD[(Pid, Array[VD])] =
VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, true)
val noAttrs: RDD[(Pid, Array[VD])] =
VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, false)
def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[(Pid, Array[VD])] =
(includeSrcAttr, includeDstAttr) match {
case (true, true) => bothAttrs
case (true, false) => srcAttrOnly
case (false, true) => dstAttrOnly
case (false, false) => noAttrs
}
}
object VTableReplicatedValues {
protected def createVTableReplicated[VD: ClassManifest](
vTable: VertexSetRDD[VD],
vid2pid: Vid2Pid,
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
includeSrcAttr: Boolean,
includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
// Join vid2pid and vTable, generate a shuffle dependency on the joined
// result, and get the shuffle id so we can use it on the slave.
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) {
// TODO(rxin): reuse VertexBroadcastMessage
(vid, vdata, pids) => pids.iterator.map { pid =>
new VertexBroadcastMsg[VD](pid, vid, vdata)
}
}.partitionBy(localVidMap.partitioner.get).cache()
localVidMap.zipPartitions(msgsByPartition){
(mapIter, msgsIter) =>
val (pid, vidToIndex) = mapIter.next()
assert(!mapIter.hasNext)
// Populate the vertex array using the vidToIndex map
val vertexArray = new Array[VD](vidToIndex.capacity)
for (msg <- msgsIter) {
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = msg.data
}
Iterator((pid, vertexArray))
}.cache()
// @todo assert edge table has partitioner
}
}
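createVTableReplicated builds each variant in two partition-aligned steps: first, zipJoinFlatMap pairs every vertex with the partitions that vid2pid.get(includeSrcAttr, includeDstAttr) lists for it and emits one VertexBroadcastMsg per (partition, vertex), partitioned like localVidMap; second, zipPartitions lines those messages up with the partition's vidToIndex map and writes each attribute into its assigned slot. A minimal sketch of the second step in isolation, using plain Scala collections in place of RDDs; the object and method names are hypothetical, not part of this commit.

// Illustrative sketch, assumed to sit alongside the class above in
// org.apache.spark.graph.impl so that VertexBroadcastMsg is in scope.
package org.apache.spark.graph.impl

import org.apache.spark.graph._
import org.apache.spark.util.collection.OpenHashSet

object VTableReplicatedValuesSketch {
  // Fill one partition's vertex array from its broadcast messages, mirroring
  // the zipPartitions body above without the RDD plumbing.
  def fillPartition[VD: ClassManifest](
      vidToIndex: VertexIdToIndexMap,        // this partition's localVidMap entry
      msgs: Iterator[VertexBroadcastMsg[VD]]): Array[VD] = {
    val vertexArray = new Array[VD](vidToIndex.capacity)
    for (msg <- msgs) {
      // Same slot computation as above: hash position with the flag bits masked off.
      val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
      vertexArray(ind) = msg.data
    }
    vertexArray
  }
}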


@@ -3,12 +3,13 @@ package org.apache.spark.graph.impl
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.graph._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.graph._
/**
* Stores the layout of vertex attributes.
* Stores the layout of vertex attributes for GraphImpl.
*/
class Vid2Pid(
eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED },
@@ -17,9 +18,16 @@ class Vid2Pid(
val bothAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(true, true)
val srcAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(true, false)
val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true)
// TODO(ankurdave): create this more efficiently
val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false)
def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] =
(includeSrcAttr, includeDstAttr) match {
case (true, true) => bothAttrs
case (true, false) => srcAttrOnly
case (false, true) => dstAttrOnly
case (false, false) => noAttrs
}
def persist(newLevel: StorageLevel) {
bothAttrs.persist(newLevel)
srcAttrOnly.persist(newLevel)
@@ -33,10 +41,12 @@ class Vid2Pid(
val preAgg = eTable.mapPartitions { iter =>
val (pid, edgePartition) = iter.next()
val vSet = new VertexSet
if (includeSrcAttr || includeDstAttr) {
edgePartition.foreach(e => {
if (includeSrcAttr) vSet.add(e.srcId)
if (includeDstAttr) vSet.add(e.dstId)
})
}
vSet.iterator.map { vid => (vid.toLong, pid) }
}
VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,