diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala index d89143505e..2ab3759e4c 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/VertexPartition.scala @@ -2,7 +2,7 @@ package org.apache.spark.graph.impl import org.apache.spark.util.collection.{BitSet, PrimitiveKeyOpenHashMap} -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.Logging import org.apache.spark.graph._ @@ -184,8 +184,7 @@ class VertexPartition[@specialized(Long, Int, Double) VD: ClassManifest]( for ((k, v) <- this.iterator) { hashMap.setMerge(k, v, arbitraryMerge) } - // TODO: Is this a bug? Why are we using index.getBitSet here? - new VertexPartition(hashMap.keySet, hashMap._values, index.getBitSet) + new VertexPartition(hashMap.keySet, hashMap._values, hashMap.keySet.getBitSet) } def iterator: Iterator[(Vid, VD)] = mask.iterator.map(ind => (index.getValue(ind), values(ind))) diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala index 47b98cfd80..2a040de7fe 100644 --- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala +++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala @@ -134,4 +134,15 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("subgraph") { + withSpark(new SparkContext("local", "test")) { sc => + val n = 10 + val star = Graph(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))), "defaultValue") + val subgraph = star.subgraph(vpred = (vid, attr) => vid % 2 == 0) + assert(subgraph.vertices.collect().toSet === + (0 to n / 2).map(x => (x * 2, "defaultValue")).toSet) + assert(subgraph.edges.collect().toSet === (1 to n / 2).map(x => Edge(0, x * 2)).toSet) + } + } + }