Factor out VTableReplicatedValues
This commit is contained in:
parent
cdbd19bbee
commit
bf4e45e685
|
@ -63,6 +63,13 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
|
|||
|
||||
/**
|
||||
* A Graph RDD that supports computation on graphs.
|
||||
*
|
||||
* @param localVidMap Stores the location of vertex attributes after they are
|
||||
* replicated. Within each partition, localVidMap holds a map from vertex ID to
|
||||
* the index where that vertex's attribute is stored. This index refers to the
|
||||
* arrays in the same partition in the variants of
|
||||
* [[VTableReplicatedValues]]. Therefore, localVidMap can be reused across
|
||||
* changes to the vertex attributes.
|
||||
*/
|
||||
class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
||||
@transient val vTable: VertexSetRDD[VD],
|
||||
|
@ -73,27 +80,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
|||
|
||||
def this() = this(null, null, null, null)
|
||||
|
||||
/**
|
||||
* (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
|
||||
* vertex data after it is replicated. Within each partition, it holds a map
|
||||
* from vertex ID to the index where that vertex's attribute is stored. This
|
||||
* index refers to an array in the same partition in vTableReplicatedValues.
|
||||
*
|
||||
* (vTableReplicatedValues: VertexSetRDD[Pid, Array[VD]]) holds the vertex data
|
||||
* and is arranged as described above.
|
||||
*/
|
||||
@transient val vTableReplicatedValuesBothAttrs: RDD[(Pid, Array[VD])] =
|
||||
createVTableReplicated(vTable, vid2pid.bothAttrs, localVidMap)
|
||||
|
||||
@transient val vTableReplicatedValuesSrcAttrOnly: RDD[(Pid, Array[VD])] =
|
||||
createVTableReplicated(vTable, vid2pid.srcAttrOnly, localVidMap)
|
||||
|
||||
@transient val vTableReplicatedValuesDstAttrOnly: RDD[(Pid, Array[VD])] =
|
||||
createVTableReplicated(vTable, vid2pid.dstAttrOnly, localVidMap)
|
||||
|
||||
// TODO(ankurdave): create this more efficiently
|
||||
@transient val vTableReplicatedValuesNoAttrs: RDD[(Pid, Array[VD])] =
|
||||
createVTableReplicated(vTable, vid2pid.noAttrs, localVidMap)
|
||||
@transient val vTableReplicatedValues: VTableReplicatedValues[VD] =
|
||||
new VTableReplicatedValues(vTable, vid2pid, localVidMap)
|
||||
|
||||
/** Return a RDD of vertices. */
|
||||
@transient override val vertices = vTable
|
||||
|
@ -105,7 +93,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
|||
|
||||
/** Return a RDD that brings edges with its source and destination vertices together. */
|
||||
@transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
|
||||
makeTriplets(localVidMap, vTableReplicatedValuesBothAttrs, eTable)
|
||||
makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
|
||||
|
||||
override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
|
||||
eTable.persist(newLevel)
|
||||
|
@ -188,9 +176,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
|
|||
traverseLineage(localVidMap, " ", visited)
|
||||
visited += (localVidMap.id -> "localVidMap")
|
||||
|
||||
println("\n\nvTableReplicatedValuesBothAttrs -----------------")
|
||||
traverseLineage(vTableReplicatedValuesBothAttrs, " ", visited)
|
||||
visited += (vTableReplicatedValuesBothAttrs.id -> "vTableReplicatedValuesBothAttrs")
|
||||
println("\n\nvTableReplicatedValues.bothAttrs ----------------")
|
||||
traverseLineage(vTableReplicatedValues.bothAttrs, " ", visited)
|
||||
visited += (vTableReplicatedValues.bothAttrs.id -> "vTableReplicatedValues.bothAttrs")
|
||||
|
||||
println("\n\ntriplets ----------------------------------------")
|
||||
traverseLineage(triplets, " ", visited)
|
||||
|
@ -386,8 +374,9 @@ object GraphImpl {
|
|||
}, preservesPartitioning = true).cache()
|
||||
}
|
||||
|
||||
protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
|
||||
RDD[(Pid, VertexIdToIndexMap)] = {
|
||||
private def createLocalVidMap(
|
||||
eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED }
|
||||
): RDD[(Pid, VertexIdToIndexMap)] = {
|
||||
eTable.mapPartitions( _.map{ case (pid, epart) =>
|
||||
val vidToIndex = new VertexIdToIndexMap
|
||||
epart.foreach{ e =>
|
||||
|
@ -398,36 +387,6 @@ object GraphImpl {
|
|||
}, preservesPartitioning = true).cache()
|
||||
}
|
||||
|
||||
/**
 * Builds, for every edge partition, an array holding the attributes of the
 * vertices shipped to that partition.
 *
 * @param vTable the vertex attributes to replicate
 * @param vid2pid for each vertex, the ids of the partitions it is sent to
 * @param replicationMap one (partition id, vertex-id-to-index map) pair per
 *   partition; must have a partitioner, because the messages are shuffled with it
 * @return one (partition id, attribute array) pair per partition; the array
 *   has `capacity` slots and is indexed by the position the map reports for
 *   each vertex id
 */
protected def createVTableReplicated[VD: ClassManifest](
    vTable: VertexSetRDD[VD],
    vid2pid: VertexSetRDD[Array[Pid]],
    replicationMap: RDD[(Pid, VertexIdToIndexMap)]):
  RDD[(Pid, Array[VD])] = {
  // Join vid2pid and vTable, generate a shuffle dependency on the joined
  // result, and get the shuffle id so we can use it on the slave.
  val routedMsgs = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
    // TODO(rxin): reuse VertexBroadcastMessage
    pids.iterator.map(pid => new VertexBroadcastMsg[VD](pid, vid, vdata))
  }.partitionBy(replicationMap.partitioner.get).cache()

  replicationMap.zipPartitions(routedMsgs) { (mapIter, msgsIter) =>
    // Each partition is expected to carry exactly one (pid, map) pair.
    val (pid, vidToIndex) = mapIter.next()
    assert(!mapIter.hasNext)
    // Populate the vertex array using the vidToIndex map
    val attrs = new Array[VD](vidToIndex.capacity)
    msgsIter.foreach { msg =>
      val slot = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
      attrs(slot) = msg.data
    }
    Iterator((pid, attrs))
  }.cache()

  // @todo assert edge table has partitioner
}
|
||||
|
||||
def makeTriplets[VD: ClassManifest, ED: ClassManifest](
|
||||
localVidMap: RDD[(Pid, VertexIdToIndexMap)],
|
||||
vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
|
||||
|
@ -444,7 +403,7 @@ object GraphImpl {
|
|||
def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
|
||||
g: GraphImpl[VD, ED],
|
||||
f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
|
||||
val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValuesBothAttrs){
|
||||
val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues.bothAttrs){
|
||||
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
|
||||
val (pid, edgePartition) = edgePartitionIter.next()
|
||||
val (_, vidToIndex) = vidToIndexIter.next()
|
||||
|
@ -476,15 +435,12 @@ object GraphImpl {
|
|||
BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr")
|
||||
val mapFuncUsesDstAttr =
|
||||
BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr")
|
||||
val vTableReplicatedValues = (mapFuncUsesSrcAttr, mapFuncUsesDstAttr) match {
|
||||
case (true, true) => g.vTableReplicatedValuesBothAttrs
|
||||
case (true, false) => g.vTableReplicatedValuesSrcAttrOnly
|
||||
case (false, true) => g.vTableReplicatedValuesDstAttrOnly
|
||||
case (false, false) => g.vTableReplicatedValuesNoAttrs
|
||||
}
|
||||
|
||||
// Map and preaggregate
|
||||
val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){
|
||||
val preAgg = g.eTable.zipPartitions(
|
||||
g.localVidMap,
|
||||
g.vTableReplicatedValues.get(mapFuncUsesSrcAttr, mapFuncUsesDstAttr)
|
||||
){
|
||||
(edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
|
||||
val (_, edgePartition) = edgePartitionIter.next()
|
||||
val (_, vidToIndex) = vidToIndexIter.next()
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
package org.apache.spark.graph.impl
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.util.collection.OpenHashSet
|
||||
|
||||
import org.apache.spark.graph._
|
||||
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||
|
||||
/**
 * Holds the replicated vertex attribute values, in four variants that differ
 * in which vertices are shipped to each partition (both endpoints, sources
 * only, destinations only, or neither). See the description of localVidMap
 * in [[GraphImpl]] for how the arrays are indexed.
 *
 * @param vTable the vertex attributes to replicate
 * @param vid2pid the vertex-to-partition routing tables, one per variant
 * @param localVidMap per-partition map from vertex id to array index
 */
class VTableReplicatedValues[VD: ClassManifest](
    vTable: VertexSetRDD[VD],
    vid2pid: Vid2Pid,
    localVidMap: RDD[(Pid, VertexIdToIndexMap)]) {

  /** Variant built with both the source and the destination flag set. */
  val bothAttrs: RDD[(Pid, Array[VD])] =
    VTableReplicatedValues.createVTableReplicated(
      vTable, vid2pid, localVidMap, includeSrcAttr = true, includeDstAttr = true)

  /** Variant built with only the source flag set. */
  val srcAttrOnly: RDD[(Pid, Array[VD])] =
    VTableReplicatedValues.createVTableReplicated(
      vTable, vid2pid, localVidMap, includeSrcAttr = true, includeDstAttr = false)

  /** Variant built with only the destination flag set. */
  val dstAttrOnly: RDD[(Pid, Array[VD])] =
    VTableReplicatedValues.createVTableReplicated(
      vTable, vid2pid, localVidMap, includeSrcAttr = false, includeDstAttr = true)

  /** Variant built with neither flag set. */
  val noAttrs: RDD[(Pid, Array[VD])] =
    VTableReplicatedValues.createVTableReplicated(
      vTable, vid2pid, localVidMap, includeSrcAttr = false, includeDstAttr = false)

  /**
   * Selects the variant matching the given pair of flags.
   *
   * @param includeSrcAttr whether the source-vertex variant is wanted
   * @param includeDstAttr whether the destination-vertex variant is wanted
   * @return the variant that was built with exactly these two flags
   */
  def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[(Pid, Array[VD])] =
    if (includeSrcAttr) {
      if (includeDstAttr) bothAttrs else srcAttrOnly
    } else {
      if (includeDstAttr) dstAttrOnly else noAttrs
    }
}
|
||||
|
||||
|
||||
|
||||
object VTableReplicatedValues {

  /**
   * Builds, for every partition of `localVidMap`, an array holding the
   * attributes of the vertices routed to that partition.
   *
   * Only the companion class calls this, so it is `private`; `protected` has
   * no useful meaning on an object member because objects cannot be extended.
   *
   * @param vTable the vertex attributes to replicate
   * @param vid2pid the routing tables; the one used is
   *   `vid2pid.get(includeSrcAttr, includeDstAttr)`
   * @param localVidMap one (partition id, vertex-id-to-index map) pair per
   *   partition; must have a partitioner, because the messages are shuffled with it
   * @param includeSrcAttr forwarded to `vid2pid.get` to pick the routing table
   * @param includeDstAttr forwarded to `vid2pid.get` to pick the routing table
   * @return one (partition id, attribute array) pair per partition; the array
   *   has `capacity` slots, and slots that receive no message keep the
   *   array's default value
   * @throws IllegalArgumentException if `localVidMap` has no partitioner
   */
  private def createVTableReplicated[VD: ClassManifest](
      vTable: VertexSetRDD[VD],
      vid2pid: Vid2Pid,
      localVidMap: RDD[(Pid, VertexIdToIndexMap)],
      includeSrcAttr: Boolean,
      includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {

    // Fail fast with a clear message instead of a bare `None.get`: the shuffle
    // below is only meaningful when the messages land in the same partitions
    // as localVidMap.
    val partitioner = localVidMap.partitioner.getOrElse(
      throw new IllegalArgumentException(
        "localVidMap must have a partitioner to replicate vertex attributes"))

    // Join vid2pid and vTable, generate a shuffle dependency on the joined
    // result, and get the shuffle id so we can use it on the slave.
    val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) {
      // TODO(rxin): reuse VertexBroadcastMessage
      (vid, vdata, pids) => pids.iterator.map { pid =>
        new VertexBroadcastMsg[VD](pid, vid, vdata)
      }
    }.partitionBy(partitioner).cache()

    localVidMap.zipPartitions(msgsByPartition) { (mapIter, msgsIter) =>
      // Each partition is expected to carry exactly one (pid, map) pair.
      val (pid, vidToIndex) = mapIter.next()
      assert(!mapIter.hasNext)
      // Populate the vertex array using the vidToIndex map
      val vertexArray = new Array[VD](vidToIndex.capacity)
      for (msg <- msgsIter) {
        // NOTE(review): the mask presumably strips flag bits from the raw
        // position returned by getPos; confirm against OpenHashSet.
        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
        vertexArray(ind) = msg.data
      }
      Iterator((pid, vertexArray))
    }.cache()
  }

}
|
|
@ -3,12 +3,13 @@ package org.apache.spark.graph.impl
|
|||
import scala.collection.JavaConversions._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.graph._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
import org.apache.spark.graph._
|
||||
|
||||
/**
|
||||
* Stores the layout of vertex attributes.
|
||||
* Stores the layout of vertex attributes for GraphImpl.
|
||||
*/
|
||||
class Vid2Pid(
|
||||
eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED },
|
||||
|
@ -17,9 +18,16 @@ class Vid2Pid(
|
|||
val bothAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(true, true)
|
||||
val srcAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(true, false)
|
||||
val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true)
|
||||
// TODO(ankurdave): create this more efficiently
|
||||
val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false)
|
||||
|
||||
/**
 * Selects the routing table matching the given pair of flags.
 *
 * @param includeSrcAttr whether the source-vertex variant is wanted
 * @param includeDstAttr whether the destination-vertex variant is wanted
 * @return `bothAttrs`, `srcAttrOnly`, `dstAttrOnly` or `noAttrs`, according
 *   to which of the two flags are set
 */
def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] =
  if (includeSrcAttr) {
    if (includeDstAttr) bothAttrs else srcAttrOnly
  } else {
    if (includeDstAttr) dstAttrOnly else noAttrs
  }
|
||||
|
||||
def persist(newLevel: StorageLevel) {
|
||||
bothAttrs.persist(newLevel)
|
||||
srcAttrOnly.persist(newLevel)
|
||||
|
@ -33,10 +41,12 @@ class Vid2Pid(
|
|||
val preAgg = eTable.mapPartitions { iter =>
|
||||
val (pid, edgePartition) = iter.next()
|
||||
val vSet = new VertexSet
|
||||
if (includeSrcAttr || includeDstAttr) {
|
||||
edgePartition.foreach(e => {
|
||||
if (includeSrcAttr) vSet.add(e.srcId)
|
||||
if (includeDstAttr) vSet.add(e.dstId)
|
||||
})
|
||||
}
|
||||
vSet.iterator.map { vid => (vid.toLong, pid) }
|
||||
}
|
||||
VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
|
||||
|
|
Loading…
Reference in a new issue