diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 8455a145ff..f542ec6069 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -241,7 +241,7 @@ object Analytics extends Logging {
         var outFname = ""
         var numVPart = 4
         var numEPart = 4
-        var partitionStrategy: PartitionStrategy = RandomVertexCut
+        var partitionStrategy: PartitionStrategy = RandomVertexCut()

         options.foreach{
           case ("numIter", v) => numIter = v.toInt
@@ -251,11 +251,11 @@ object Analytics extends Logging {
           case ("numVPart", v) => numVPart = v.toInt
           case ("numEPart", v) => numEPart = v.toInt
           case ("partStrategy", v) => {
-            v match {
-              case "RandomVertexCut" => partitionStrategy = RandomVertexCut
-              case "EdgePartition1D" => partitionStrategy = EdgePartition1D
-              case "EdgePartition2D" => partitionStrategy = EdgePartition2D
-              case "CanonicalRandomVertexCut" => partitionStrategy = CanonicalRandomVertexCut
+            partitionStrategy = v match {
+              case "RandomVertexCut" => RandomVertexCut()
+              case "EdgePartition1D" => EdgePartition1D()
+              case "EdgePartition2D" => EdgePartition2D()
+              case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut()
               case _ => throw new IllegalArgumentException("Invalid Partition Strategy: " + v)
             }
           }
diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index 6ce3f5d2e7..87667f6958 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph

 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 import org.apache.spark.storage.StorageLevel

 /**
@@ -22,7 +21,7 @@ import org.apache.spark.storage.StorageLevel
  * @tparam VD the vertex attribute type
  * @tparam ED the edge attribute type
  */
-abstract class Graph[VD: ClassManifest, ED: ClassManifest] extends Logging {
+abstract class Graph[VD: ClassManifest, ED: ClassManifest] {

   /**
    * Get the vertices and their data.
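With this change a partition strategy is constructed rather than referenced as a singleton, so every call site gains a pair of parentheses. A minimal sketch of graph construction under the new API, assuming the 0.8-era graph package shown in this patch (the local SparkContext and toy RDDs are illustrative, not part of the change):

    import org.apache.spark.SparkContext
    import org.apache.spark.graph._
    import org.apache.spark.graph.impl.GraphImpl

    val sc = new SparkContext("local", "partition-strategy-demo")
    val vertices = sc.parallelize(Seq((1L, 1), (2L, 1), (3L, 1)))      // RDD[(Vid, Int)]
    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1)))  // RDD[Edge[Int]]

    // The strategy is now instantiated, matching the case-class constructors below.
    val graph = GraphImpl(vertices, edges,
      1,                      // defaultVertexAttr
      (a: Int, b: Int) => a,  // mergeFunc for duplicate vertex ids
      EdgePartition2D())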
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
index 39dc33acf0..b8503ab7fd 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
@@ -2,12 +2,11 @@ package org.apache.spark.graph

 import scala.collection.JavaConversions._

 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging

 /**
  * This object implements the GraphLab gather-apply-scatter api.
  */
-object GraphLab extends Logging {
+object GraphLab {

   /**
    * Execute the GraphLab Gather-Apply-Scatter API
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
index 813f176313..4dc33a02ce 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
@@ -28,7 +28,7 @@ object GraphLoader {
       edgeParser: Array[String] => ED,
       minEdgePartitions: Int = 1,
       minVertexPartitions: Int = 1,
-      partitionStrategy: PartitionStrategy = RandomVertexCut): GraphImpl[Int, ED] = {
+      partitionStrategy: PartitionStrategy = RandomVertexCut()): GraphImpl[Int, ED] = {

     // Parse the edge data table
     val edges = sc.textFile(path, minEdgePartitions).flatMap { line =>
diff --git a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
index f7db667e2f..cf65f50657 100644
--- a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
@@ -50,7 +50,7 @@ sealed trait PartitionStrategy extends Serializable {
  *
  *
  */
-object EdgePartition2D extends PartitionStrategy {
+case class EdgePartition2D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val ceilSqrtNumParts: Pid = math.ceil(math.sqrt(numParts)).toInt
     val mixingPrime: Vid = 1125899906842597L
@@ -61,7 +61,7 @@ object EdgePartition2D extends PartitionStrategy {
 }


-object EdgePartition1D extends PartitionStrategy {
+case class EdgePartition1D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L
     (math.abs(src) * mixingPrime).toInt % numParts
@@ -73,7 +73,7 @@ object EdgePartition1D extends PartitionStrategy {
 * Assign edges to an arbitrary machine corresponding to a
 * random vertex cut.
 */
-object RandomVertexCut extends PartitionStrategy {
+case class RandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     math.abs((src, dst).hashCode()) % numParts
   }
@@ -85,7 +85,7 @@ object RandomVertexCut extends PartitionStrategy {
 * function ensures that edges of opposite direction between the same two vertices
 * will end up on the same partition.
 */
-object CanonicalRandomVertexCut extends PartitionStrategy {
+case class CanonicalRandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val lower = math.min(src, dst)
     val higher = math.max(src, dst)
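Two properties of these strategies matter for the GraphImpl changes further down. First, CanonicalRandomVertexCut hashes the (min, max) vertex pair, so both directions of an edge map to the same partition. Second, because the strategies are now case classes rather than singleton objects, a PartitionStrategy value can be inspected by type at runtime. A short sketch of both, with arbitrarily chosen ids and partition counts:

    val strategy: PartitionStrategy = CanonicalRandomVertexCut()
    val numParts = 8

    // Symmetric by construction: parallel edges between two vertices,
    // in either direction, land in the same partition.
    assert(strategy.getPartition(3L, 7L, numParts) == strategy.getPartition(7L, 3L, numParts))

    // Type-based dispatch, as groupEdges / groupEdgeTriplets do below.
    strategy match {
      case _: CanonicalRandomVertexCut =>
        println("parallel edges are co-located; per-partition grouping is complete")
      case other =>
        println(other.getClass.getName + " does not co-locate parallel edges")
    }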
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index f3016e6ad3..3b4d3c0df2 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph

 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging


 /**
@@ -42,7 +41,7 @@ import org.apache.spark.Logging
  * }}}
  *
  */
-object Pregel extends Logging {
+object Pregel {


   /**
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 7c3d401832..6ad0ce60a7 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner
 import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.SparkException
 import org.apache.spark.Partitioner

 import org.apache.spark.graph._
@@ -97,8 +98,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)

-  //@transient private val partitioner: PartitionStrategy = partitionStrategy
-
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
     vid2pid.persist(newLevel)
@@ -250,43 +249,55 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   override def groupEdgeTriplets[ED2: ClassManifest](
       f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
-    val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
-      partIter
-      // TODO(crankshaw) toList requires that the entire edge partition
-      // can fit in memory right now.
-      .toList
-      // groups all ETs in this partition that have the same src and dst
-      // Because all ETs with the same src and dst will live on the same
-      // partition due to the canonicalRandomVertexCut partitioner, this
-      // guarantees that these ET groups will be complete.
-      .groupBy { t: EdgeTriplet[VD, ED] => (t.srcId, t.dstId) }
-      .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
-      .toList
-      .toIterator
-      .map { case ((src, dst), data) => Edge(src, dst, data) }
-      .toIterator
-    }
+    partitioner match {
+      case _: CanonicalRandomVertexCut => {
+        val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
+          partIter
+          // TODO(crankshaw) toList requires that the entire edge partition
+          // can fit in memory right now.
+          .toList
+          // groups all ETs in this partition that have the same src and dst
+          // Because all ETs with the same src and dst will live on the same
+          // partition due to the canonicalRandomVertexCut partitioner, this
+          // guarantees that these ET groups will be complete.
+          .groupBy { t: EdgeTriplet[VD, ED] => (t.srcId, t.dstId) }
+          .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
+          .toList
+          .toIterator
+          .map { case ((src, dst), data) => Edge(src, dst, data) }
+          .toIterator
+        }
+        //TODO(crankshaw) eliminate the need to call createETable
+        val newETable = createETable(newEdges, partitioner)
+        new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+      }

-    //TODO(crankshaw) eliminate the need to call createETable
-    val newETable = createETable(newEdges, partitioner)
-    new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+      case _ => throw new SparkException(partitioner.getClass.getName +
+        " is incompatible with groupEdgeTriplets")
+    }
   }

   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
     Graph[VD,ED2] = {
+    partitioner match {
+      case _: CanonicalRandomVertexCut => {
+        val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
+          partIter.toList
+          .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
+          .mapValues { ts => f(ts.toIterator) }
+          .toList
+          .toIterator
+          .map { case ((src, dst), data) => Edge(src, dst, data) }
+        }
+        // TODO(crankshaw) eliminate the need to call createETable
+        val newETable = createETable(newEdges, partitioner)

-    val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
-      partIter.toList
-      .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
-      .mapValues { ts => f(ts.toIterator) }
-      .toList
-      .toIterator
-      .map { case ((src, dst), data) => Edge(src, dst, data) }
+        new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+      }
+
+      case _ => throw new SparkException(partitioner.getClass.getName +
+        " is incompatible with groupEdges")
     }
-    // TODO(crankshaw) eliminate the need to call createETable
-    val newETable = createETable(newEdges, partitioner)
-
-    new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
   }

 //////////////////////////////////////////////////////////////////////////////////////////////////

@@ -315,7 +326,7 @@ object GraphImpl {
       vertices: RDD[(Vid, VD)],
       edges: RDD[Edge[ED]],
       defaultVertexAttr: VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut())
   }

   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -331,7 +342,7 @@ object GraphImpl {
       edges: RDD[Edge[ED]],
       defaultVertexAttr: VD,
       mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut())
   }

   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -362,14 +373,6 @@ object GraphImpl {
   }

-
-
-  // TODO(crankshaw) - can I remove this
-  //protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
-  //  : RDD[(Pid, EdgePartition[ED])] = {
-  //  createETable(edges, RandomVertexCut)
-  //}
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
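Net effect of the GraphImpl changes above: groupEdges and groupEdgeTriplets now fail fast when the graph was not built with CanonicalRandomVertexCut, instead of silently returning incomplete groups for strategies that scatter parallel edges across partitions. A hedged usage sketch, reusing the vertices and edges RDDs from the earlier example (the size-based merge function is made up for illustration):

    // Only a canonical random vertex cut guarantees complete edge groups.
    val graph = GraphImpl(vertices, edges, 1, (a: Int, b: Int) => a, CanonicalRandomVertexCut())

    // Collapse parallel edges into a single edge carrying their multiplicity.
    val collapsed = graph.groupEdges(edgeIter => edgeIter.size)

    // The same call on a RandomVertexCut()-partitioned graph now throws
    // SparkException instead of producing partially merged edges.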