diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index ead072dcb8..d16a81d203 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -322,15 +322,20 @@ object GraphImpl { mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = { vertices.cache - edges.cache - // Get the set of all vids - val allVids = vertices.map(_._1).union(edges.flatMap(e => Seq(e.srcId, e.dstId))) + val etable = createETable(edges).cache + // Get the set of all vids, preserving partitions + val partitioner = Partitioner.defaultPartitioner(vertices) + val implicitVids = etable.flatMap { + case (pid, partition) => Array.concat(partition.srcIds, partition.dstIds) + }.map(vid => (vid, ())).partitionBy(partitioner) + val allVids = vertices.zipPartitions(implicitVids) { + (a, b) => a.map(_._1) ++ b.map(_._1) + } // Index the set of all vids - val index = VertexSetRDD.makeIndex(allVids, Some(Partitioner.defaultPartitioner(vertices))) + val index = VertexSetRDD.makeIndex(allVids, Some(partitioner)) // Index the vertices and fill in missing attributes with the default val vtable = VertexSetRDD(vertices, index, mergeFunc).fillMissing(defaultVertexAttr) - val etable = createETable(edges) val vid2pid = new Vid2Pid(etable, vtable.index) val localVidMap = createLocalVidMap(etable) new GraphImpl(vtable, vid2pid, localVidMap, etable)