diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index 7c0b9e23f2..ae1ea715e2 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -63,6 +63,13 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( /** * A Graph RDD that supports computation on graphs. + * + * @param localVidMap Stores the location of vertex attributes after they are + * replicated. Within each partition, localVidMap holds a map from vertex ID to + * the index where that vertex's attribute is stored. This index refers to the + * arrays in the same partition in the variants of + * [[VTableReplicatedValues]]. Therefore, localVidMap can be reused across + * changes to the vertex attributes. */ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( @transient val vTable: VertexSetRDD[VD], @@ -73,27 +80,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( def this() = this(null, null, null, null) - /** - * (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the - * vertex data after it is replicated. Within each partition, it holds a map - * from vertex ID to the index where that vertex's attribute is stored. This - * index refers to an array in the same partition in vTableReplicatedValues. - * - * (vTableReplicatedValues: VertexSetRDD[Pid, Array[VD]]) holds the vertex data - * and is arranged as described above. - */ - @transient val vTableReplicatedValuesBothAttrs: RDD[(Pid, Array[VD])] = - createVTableReplicated(vTable, vid2pid.bothAttrs, localVidMap) - - @transient val vTableReplicatedValuesSrcAttrOnly: RDD[(Pid, Array[VD])] = - createVTableReplicated(vTable, vid2pid.srcAttrOnly, localVidMap) - - @transient val vTableReplicatedValuesDstAttrOnly: RDD[(Pid, Array[VD])] = - createVTableReplicated(vTable, vid2pid.dstAttrOnly, localVidMap) - - // TODO(ankurdave): create this more efficiently - @transient val vTableReplicatedValuesNoAttrs: RDD[(Pid, Array[VD])] = - createVTableReplicated(vTable, vid2pid.noAttrs, localVidMap) + @transient val vTableReplicatedValues: VTableReplicatedValues[VD] = + new VTableReplicatedValues(vTable, vid2pid, localVidMap) /** Return a RDD of vertices. */ @transient override val vertices = vTable @@ -105,7 +93,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( /** Return a RDD that brings edges with its source and destination vertices together. */ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = - makeTriplets(localVidMap, vTableReplicatedValuesBothAttrs, eTable) + makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable) override def persist(newLevel: StorageLevel): Graph[VD, ED] = { eTable.persist(newLevel) @@ -188,9 +176,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( traverseLineage(localVidMap, " ", visited) visited += (localVidMap.id -> "localVidMap") - println("\n\nvTableReplicatedValuesBothAttrs -----------------") - traverseLineage(vTableReplicatedValuesBothAttrs, " ", visited) - visited += (vTableReplicatedValuesBothAttrs.id -> "vTableReplicatedValuesBothAttrs") + println("\n\nvTableReplicatedValues.bothAttrs ----------------") + traverseLineage(vTableReplicatedValues.bothAttrs, " ", visited) + visited += (vTableReplicatedValues.bothAttrs.id -> "vTableReplicatedValues.bothAttrs") println("\n\ntriplets ----------------------------------------") traverseLineage(triplets, " ", visited) @@ -386,8 +374,9 @@ object GraphImpl { }, preservesPartitioning = true).cache() } - protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]): - RDD[(Pid, VertexIdToIndexMap)] = { + private def createLocalVidMap( + eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED } + ): RDD[(Pid, VertexIdToIndexMap)] = { eTable.mapPartitions( _.map{ case (pid, epart) => val vidToIndex = new VertexIdToIndexMap epart.foreach{ e => @@ -398,36 +387,6 @@ object GraphImpl { }, preservesPartitioning = true).cache() } - protected def createVTableReplicated[VD: ClassManifest]( - vTable: VertexSetRDD[VD], - vid2pid: VertexSetRDD[Array[Pid]], - replicationMap: RDD[(Pid, VertexIdToIndexMap)]): - RDD[(Pid, Array[VD])] = { - // Join vid2pid and vTable, generate a shuffle dependency on the joined - // result, and get the shuffle id so we can use it on the slave. - val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) => - // TODO(rxin): reuse VertexBroadcastMessage - pids.iterator.map { pid => - new VertexBroadcastMsg[VD](pid, vid, vdata) - } - }.partitionBy(replicationMap.partitioner.get).cache() - - replicationMap.zipPartitions(msgsByPartition){ - (mapIter, msgsIter) => - val (pid, vidToIndex) = mapIter.next() - assert(!mapIter.hasNext) - // Populate the vertex array using the vidToIndex map - val vertexArray = new Array[VD](vidToIndex.capacity) - for (msg <- msgsIter) { - val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK - vertexArray(ind) = msg.data - } - Iterator((pid, vertexArray)) - }.cache() - - // @todo assert edge table has partitioner - } - def makeTriplets[VD: ClassManifest, ED: ClassManifest]( localVidMap: RDD[(Pid, VertexIdToIndexMap)], vTableReplicatedValues: RDD[(Pid, Array[VD]) ], @@ -444,7 +403,7 @@ object GraphImpl { def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest]( g: GraphImpl[VD, ED], f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValuesBothAttrs){ + val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues.bothAttrs){ (edgePartitionIter, vidToIndexIter, vertexArrayIter) => val (pid, edgePartition) = edgePartitionIter.next() val (_, vidToIndex) = vidToIndexIter.next() @@ -476,15 +435,12 @@ object GraphImpl { BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr") val mapFuncUsesDstAttr = BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr") - val vTableReplicatedValues = (mapFuncUsesSrcAttr, mapFuncUsesDstAttr) match { - case (true, true) => g.vTableReplicatedValuesBothAttrs - case (true, false) => g.vTableReplicatedValuesSrcAttrOnly - case (false, true) => g.vTableReplicatedValuesDstAttrOnly - case (false, false) => g.vTableReplicatedValuesNoAttrs - } // Map and preaggregate - val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){ + val preAgg = g.eTable.zipPartitions( + g.localVidMap, + g.vTableReplicatedValues.get(mapFuncUsesSrcAttr, mapFuncUsesDstAttr) + ){ (edgePartitionIter, vidToIndexIter, vertexArrayIter) => val (_, edgePartition) = edgePartitionIter.next() val (_, vidToIndex) = vidToIndexIter.next() diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala new file mode 100644 index 0000000000..a9ab6255fa --- /dev/null +++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala @@ -0,0 +1,72 @@ +package org.apache.spark.graph.impl + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.collection.OpenHashSet + +import org.apache.spark.graph._ +import org.apache.spark.graph.impl.MsgRDDFunctions._ + +/** + * Stores the vertex attribute values after they are replicated. See + * the description of localVidMap in [[GraphImpl]]. + */ +class VTableReplicatedValues[VD: ClassManifest]( + vTable: VertexSetRDD[VD], + vid2pid: Vid2Pid, + localVidMap: RDD[(Pid, VertexIdToIndexMap)]) { + + val bothAttrs: RDD[(Pid, Array[VD])] = + VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, true) + val srcAttrOnly: RDD[(Pid, Array[VD])] = + VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, false) + val dstAttrOnly: RDD[(Pid, Array[VD])] = + VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, true) + val noAttrs: RDD[(Pid, Array[VD])] = + VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, false) + + + def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = + (includeSrcAttr, includeDstAttr) match { + case (true, true) => bothAttrs + case (true, false) => srcAttrOnly + case (false, true) => dstAttrOnly + case (false, false) => noAttrs + } +} + + + +object VTableReplicatedValues { + protected def createVTableReplicated[VD: ClassManifest]( + vTable: VertexSetRDD[VD], + vid2pid: Vid2Pid, + localVidMap: RDD[(Pid, VertexIdToIndexMap)], + includeSrcAttr: Boolean, + includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = { + + // Join vid2pid and vTable, generate a shuffle dependency on the joined + // result, and get the shuffle id so we can use it on the slave. + val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) { + // TODO(rxin): reuse VertexBroadcastMessage + (vid, vdata, pids) => pids.iterator.map { pid => + new VertexBroadcastMsg[VD](pid, vid, vdata) + } + }.partitionBy(localVidMap.partitioner.get).cache() + + localVidMap.zipPartitions(msgsByPartition){ + (mapIter, msgsIter) => + val (pid, vidToIndex) = mapIter.next() + assert(!mapIter.hasNext) + // Populate the vertex array using the vidToIndex map + val vertexArray = new Array[VD](vidToIndex.capacity) + for (msg <- msgsIter) { + val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK + vertexArray(ind) = msg.data + } + Iterator((pid, vertexArray)) + }.cache() + + // @todo assert edge table has partitioner + } + +} diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala index d8c8d35ee1..9bdca7f407 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala @@ -3,12 +3,13 @@ package org.apache.spark.graph.impl import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.graph._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.graph._ + /** - * Stores the layout of vertex attributes. + * Stores the layout of vertex attributes for GraphImpl. */ class Vid2Pid( eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED }, @@ -17,9 +18,16 @@ class Vid2Pid( val bothAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(true, true) val srcAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(true, false) val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true) - // TODO(ankurdave): create this more efficiently val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false) + def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = + (includeSrcAttr, includeDstAttr) match { + case (true, true) => bothAttrs + case (true, false) => srcAttrOnly + case (false, true) => dstAttrOnly + case (false, false) => noAttrs + } + def persist(newLevel: StorageLevel) { bothAttrs.persist(newLevel) srcAttrOnly.persist(newLevel) @@ -28,15 +36,17 @@ class Vid2Pid( } private def createVid2Pid( - includeSrcAttr: Boolean, - includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = { + includeSrcAttr: Boolean, + includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = { val preAgg = eTable.mapPartitions { iter => val (pid, edgePartition) = iter.next() val vSet = new VertexSet - edgePartition.foreach(e => { - if (includeSrcAttr) vSet.add(e.srcId) - if (includeDstAttr) vSet.add(e.dstId) - }) + if (includeSrcAttr || includeDstAttr) { + edgePartition.foreach(e => { + if (includeSrcAttr) vSet.add(e.srcId) + if (includeDstAttr) vSet.add(e.dstId) + }) + } vSet.iterator.map { vid => (vid.toLong, pid) } } VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,