From 502c5117110574ac1daf3d8347fb2ad71da80e71 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Sat, 9 Nov 2013 22:10:29 -0800
Subject: [PATCH 1/2] Use pid2vid for creating VTableReplicatedValues

---
 .../spark/graph/GraphKryoRegistrator.scala    |  1 +
 .../graph/impl/VTableReplicatedValues.scala   | 48 ++++++++++++++-----
 2 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index 82b9198e43..baf8099556 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -18,6 +18,7 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[EdgePartition[Object]])
     kryo.register(classOf[BitSet])
     kryo.register(classOf[VertexIdToIndexMap])
+    kryo.register(classOf[VertexAttributeBlock[Object]])
     // This avoids a large number of hash table lookups.
     kryo.setReferences(false)
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
index a9ab6255fa..25cd1b8054 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
@@ -1,7 +1,10 @@
 package org.apache.spark.graph.impl
 
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.collection.OpenHashSet
+import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap}
 
 import org.apache.spark.graph._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
@@ -34,7 +37,7 @@ class VTableReplicatedValues[VD: ClassManifest](
     }
 }
 
-
+class VertexAttributeBlock[VD: ClassManifest](val vids: Array[Vid], val attrs: Array[VD])
 
 object VTableReplicatedValues {
   protected def createVTableReplicated[VD: ClassManifest](
@@ -44,13 +47,30 @@ object VTableReplicatedValues {
       includeSrcAttr: Boolean,
       includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
 
-    // Join vid2pid and vTable, generate a shuffle dependency on the joined
-    // result, and get the shuffle id so we can use it on the slave.
-    val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) {
-      // TODO(rxin): reuse VertexBroadcastMessage
-      (vid, vdata, pids) => pids.iterator.map { pid =>
-        new VertexBroadcastMsg[VD](pid, vid, vdata)
+    // Within each partition of vid2pid, construct a pid2vid mapping
+    val numPartitions = vTable.partitions.size
+    val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter =>
+      val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
+      for ((vid, pids) <- iter) {
+        pids.foreach { pid => pid2vidLocal(pid) += vid }
       }
+      Iterator(pid2vidLocal.map(_.result))
+    }
+
+    val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) {
+      (pid2vidIter, indexIter, valuesIter) =>
+      val pid2vid = pid2vidIter.next()
+      val index = indexIter.next()
+      val values = valuesIter.next()
+      val vmap = new PrimitiveKeyOpenHashMap(index, values._1)
+
+      // Send each partition the vertex attributes it wants
+      val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
+      for (pid <- 0 until pid2vid.size) {
+        val block = new VertexAttributeBlock(pid2vid(pid), pid2vid(pid).map(vid => vmap(vid)))
+        output(pid) = (pid, block)
+      }
+      output.iterator
     }.partitionBy(localVidMap.partitioner.get).cache()
 
     localVidMap.zipPartitions(msgsByPartition){
@@ -59,14 +79,16 @@ object VTableReplicatedValues {
       assert(!mapIter.hasNext)
       // Populate the vertex array using the vidToIndex map
       val vertexArray = new Array[VD](vidToIndex.capacity)
-      for (msg <- msgsIter) {
-        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
-        vertexArray(ind) = msg.data
+      for ((_, block) <- msgsIter) {
+        for (i <- 0 until block.vids.size) {
+          val vid = block.vids(i)
+          val attr = block.attrs(i)
+          val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
+          vertexArray(ind) = attr
+        }
       }
       Iterator((pid, vertexArray))
     }.cache()
-
-    // @todo assert edge table has partitioner
   }
 
 }

From d1ff1b722274de8e03938452d8155f2a26c55f96 Mon Sep 17 00:00:00 2001
From: Ankur Dave
Date: Sun, 10 Nov 2013 01:51:42 -0800
Subject: [PATCH 2/2] Build pid2vid structures only once, in Vid2Pid

---
 .../graph/impl/VTableReplicatedValues.scala   | 12 +-------
 .../org/apache/spark/graph/impl/Vid2Pid.scala | 29 +++++++++++++++++++
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
index 25cd1b8054..fee2d40ee4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
@@ -1,7 +1,5 @@
 package org.apache.spark.graph.impl
 
-import scala.collection.mutable.ArrayBuilder
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -47,15 +45,7 @@ object VTableReplicatedValues {
       includeSrcAttr: Boolean,
       includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
 
-    // Within each partition of vid2pid, construct a pid2vid mapping
-    val numPartitions = vTable.partitions.size
-    val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter =>
-      val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
-      for ((vid, pids) <- iter) {
-        pids.foreach { pid => pid2vidLocal(pid) += vid }
-      }
-      Iterator(pid2vidLocal.map(_.result))
-    }
+    val pid2vid = vid2pid.getPid2Vid(includeSrcAttr, includeDstAttr)
 
     val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) {
       (pid2vidIter, indexIter, valuesIter) =>
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
index 9bdca7f407..363adbbce9 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
@@ -2,6 +2,7 @@ package org.apache.spark.graph.impl
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuilder
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -20,6 +21,11 @@ class Vid2Pid(
   val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true)
   val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false)
 
+  val pid2VidBothAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(bothAttrs)
+  val pid2VidSrcAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(srcAttrOnly)
+  val pid2VidDstAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(dstAttrOnly)
+  val pid2VidNoAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(noAttrs)
+
   def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] =
     (includeSrcAttr, includeDstAttr) match {
       case (true, true) => bothAttrs
@@ -28,6 +34,14 @@ class Vid2Pid(
       case (false, false) => noAttrs
     }
 
+  def getPid2Vid(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[Vid]]] =
+    (includeSrcAttr, includeDstAttr) match {
+      case (true, true) => pid2VidBothAttrs
+      case (true, false) => pid2VidSrcAttrOnly
+      case (false, true) => pid2VidDstAttrOnly
+      case (false, false) => pid2VidNoAttrs
+    }
+
   def persist(newLevel: StorageLevel) {
     bothAttrs.persist(newLevel)
     srcAttrOnly.persist(newLevel)
@@ -55,4 +69,19 @@ class Vid2Pid(
         (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b)
       .mapValues(a => a.toArray).cache()
   }
+
+  /**
+   * Creates an intermediate pid2vid structure that tells each partition of the
+   * vertex data where it should go.
+   */
+  private def createPid2Vid(vid2pid: VertexSetRDD[Array[Pid]]): RDD[Array[Array[Vid]]] = {
+    val numPartitions = vid2pid.partitions.size
+    vid2pid.mapPartitions { iter =>
+      val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
+      for ((vid, pids) <- iter) {
+        pids.foreach { pid => pid2vidLocal(pid) += vid }
+      }
+      Iterator(pid2vidLocal.map(_.result))
+    }
+  }
 }
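
Note on the data structure both patches revolve around: pid2vid inverts the per-vertex routing table (vid -> partitions that reference it) into, for each partition, the array of vertex ids it needs, so vertex attributes can be shipped as a few large VertexAttributeBlock batches instead of one message per (vertex, partition) pair. The sketch below is a minimal, Spark-free illustration of that inversion in plain Scala; the local Vid/Pid type aliases, the object name, and the sample data are illustrative stand-ins and are not part of the patches.

import scala.collection.mutable.ArrayBuilder

object Pid2VidSketch {
  // Illustrative stand-ins for the graph package's type aliases.
  type Vid = Long
  type Pid = Int

  // Invert a (vid -> pids) routing table into, for each pid, the vids it needs.
  def buildPid2Vid(vid2pid: Iterator[(Vid, Array[Pid])], numPartitions: Int): Array[Array[Vid]] = {
    val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
    for ((vid, pids) <- vid2pid) {
      pids.foreach { pid => pid2vidLocal(pid) += vid }
    }
    pid2vidLocal.map(_.result())
  }

  def main(args: Array[String]): Unit = {
    // Vertex 1 is referenced by partitions 0 and 2, vertex 2 only by partition 1, etc.
    val vid2pid = Iterator(
      1L -> Array(0, 2),
      2L -> Array(1),
      3L -> Array(0, 1, 2))
    buildPid2Vid(vid2pid, numPartitions = 3).zipWithIndex.foreach { case (vids, pid) =>
      println(s"partition $pid needs vertices ${vids.mkString(", ")}")
    }
  }
}

The first patch performs this inversion inside createVTableReplicated on every call; the second hoists it into Vid2Pid.createPid2Vid so it is built only once per attribute combination and then reused.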