diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 6beaea07fa..a6e808cdbe 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -241,6 +241,7 @@ object Analytics extends Logging {
          var outFname = ""
          var numVPart = 4
          var numEPart = 4
+         var partitionStrategy: PartitionStrategy = RandomVertexCut
 
          options.foreach{
            case ("numIter", v) => numIter = v.toInt
@@ -249,6 +250,15 @@ object Analytics extends Logging {
            case ("output", v) => outFname = v
            case ("numVPart", v) => numVPart = v.toInt
            case ("numEPart", v) => numEPart = v.toInt
+           case ("partStrategy", v) => {
+             v match {
+               case "RandomVertexCut" => partitionStrategy = RandomVertexCut
+               case "EdgePartition1D" => partitionStrategy = EdgePartition1D
+               case "EdgePartition2D" => partitionStrategy = EdgePartition2D
+               case "CanonicalRandomVertexCut" => partitionStrategy = CanonicalRandomVertexCut
+               case _ => throw new IllegalArgumentException("Invalid Partition Strategy: " + v)
+             }
+           }
            case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt)
          }
 
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index baf8099556..6f18e46ab2 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -5,6 +5,7 @@ import com.esotericsoftware.kryo.Kryo
 import org.apache.spark.graph.impl._
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.util.collection.BitSet
+import org.apache.spark.graph._
 
 
 class GraphKryoRegistrator extends KryoRegistrator {
@@ -19,6 +20,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[BitSet])
     kryo.register(classOf[VertexIdToIndexMap])
     kryo.register(classOf[VertexAttributeBlock[Object]])
+    kryo.register(classOf[PartitionStrategy])
+
     // This avoids a large number of hash table lookups.
     kryo.setReferences(false)
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
index 313737fdbe..f314083353 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
@@ -27,8 +27,8 @@ object GraphLoader {
       path: String,
       edgeParser: Array[String] => ED,
       minEdgePartitions: Int = 1,
-      minVertexPartitions: Int = 1)
-    : GraphImpl[Int, ED] = {
+      minVertexPartitions: Int = 1,
+      partitionStrategy: PartitionStrategy = RandomVertexCut): GraphImpl[Int, ED] = {
 
     // Parse the edge data table
     val edges = sc.textFile(path, minEdgePartitions).flatMap { line =>
@@ -48,13 +48,15 @@ object GraphLoader {
       }
     }.cache()
 
-    val graph = fromEdges(edges)
+    val graph = fromEdges(edges, partitionStrategy)
     graph
   }
 
-  private def fromEdges[ED: ClassManifest](edges: RDD[Edge[ED]]): GraphImpl[Int, ED] = {
+  private def fromEdges[ED: ClassManifest](
+    edges: RDD[Edge[ED]],
+    partitionStrategy: PartitionStrategy): GraphImpl[Int, ED] = {
     val vertices = edges.flatMap { edge => List((edge.srcId, 1), (edge.dstId, 1)) }
       .reduceByKey(_ + _)
-    GraphImpl(vertices, edges, 0)
+    GraphImpl(vertices, edges, 0, (a: Int, b: Int) => a, partitionStrategy)
   }
 }
diff --git a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
new file mode 100644
index 0000000000..caf96ad9ce
--- /dev/null
+++ b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
@@ -0,0 +1,60 @@
+package org.apache.spark.graph
+
+/**
+ * Assigns each edge, identified by its endpoint vertex ids, to an edge partition.
+ */
+sealed trait PartitionStrategy extends Serializable {
+  /** Returns the partition for the edge (src, dst); for numParts > 0 it is in [0, numParts). */
+  def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid
+
+  /**
+   * Shifts a remainder from the open interval (-modulus, modulus) into [0, modulus).
+   * Needed because `%` takes the sign of its left operand, and the hashes below can be
+   * negative (wrapped Long products, truncation to Int, math.abs of the minimum value).
+   */
+  protected def nonNegative(rem: Pid, modulus: Pid): Pid =
+    if (rem < 0) rem + modulus else rem
+}
+
+/**
+ * Places src and dst independently into one of ceil(sqrt(numParts)) columns and rows and
+ * combines the two coordinates into a partition id.
+ */
+object EdgePartition2D extends PartitionStrategy {
+  override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
+    val ceilSqrtNumParts: Pid = math.ceil(math.sqrt(numParts)).toInt
+    val mixingPrime: Vid = 1125899906842597L
+    val col: Pid =
+      nonNegative(((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt, ceilSqrtNumParts)
+    val row: Pid =
+      nonNegative(((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt, ceilSqrtNumParts)
+    (col * ceilSqrtNumParts + row) % numParts
+  }
+}
+
+/** Chooses the partition from the source vertex id only; dst is ignored. */
+object EdgePartition1D extends PartitionStrategy {
+  override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
+    val mixingPrime: Vid = 1125899906842597L
+    nonNegative((math.abs(src) * mixingPrime).toInt % numParts, numParts)
+  }
+}
+
+/** Chooses the partition from the hash of the ordered pair (src, dst). */
+object RandomVertexCut extends PartitionStrategy {
+  override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
+    nonNegative(math.abs((src, dst).hashCode()) % numParts, numParts)
+  }
+}
+
+/**
+ * Like RandomVertexCut, but hashes (min(src, dst), max(src, dst)), so both directions of
+ * an edge between the same two vertices map to the same partition.
+ */
+object CanonicalRandomVertexCut extends PartitionStrategy {
+  override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
+    val lower = math.min(src, dst)
+    val higher = math.max(src, dst)
+    nonNegative(math.abs((lower, higher).hashCode()) % numParts, numParts)
+  }
+}
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 693bb888bc..b529a6964e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -308,10 +308,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 object GraphImpl {
 
   def apply[VD: ClassManifest, ED: ClassManifest](
-    vertices: RDD[(Vid, VD)], edges: RDD[Edge[ED]],
-    defaultVertexAttr: VD):
-  GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
+      vertices: RDD[(Vid, VD)],
+      edges: RDD[Edge[ED]],
+      defaultVertexAttr: VD): GraphImpl[VD,ED] = {
+    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut)
+  }
+
+  def apply[VD: ClassManifest, ED: ClassManifest](
+      vertices: RDD[(Vid, VD)],
+      edges: RDD[Edge[ED]],
+      defaultVertexAttr: VD,
+      partitionStrategy: PartitionStrategy): GraphImpl[VD,ED] = {
+    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, partitionStrategy)
   }
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -319,6 +327,17 @@ object GraphImpl {
       edges: RDD[Edge[ED]],
       defaultVertexAttr: VD,
       mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = {
+    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut)
+  }
+
+  def apply[VD: ClassManifest, ED: ClassManifest](
+      vertices: RDD[(Vid, VD)],
+      edges: RDD[Edge[ED]],
+      defaultVertexAttr: VD,
+      mergeFunc: (VD, VD) => VD,
+      partitionStrategy: PartitionStrategy): GraphImpl[VD,ED] = {
+    // TODO(review): confirm the createETable call later in this method passes
+    // partitionStrategy; createETable(edges) alone always uses RandomVertexCut.
     val vtable = VertexSetRDD(vertices, mergeFunc)
 
     /**
@@ -339,6 +358,14 @@ object GraphImpl {
     new GraphImpl(vtable, vid2pid, localVidMap, etable)
   }
 
+
+
+
+  protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
+    : RDD[(Pid, EdgePartition[ED])] = {
+    createETable(edges, RandomVertexCut)
+  }
+
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
@@ -347,16 +374,18 @@ object GraphImpl {
    * key-value pair: the key is the partition id, and the value is an EdgePartition object
    * containing all the edges in a partition.
    */
-  protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
-    : RDD[(Pid, EdgePartition[ED])] = {
+  protected def createETable[ED: ClassManifest](
+      edges: RDD[Edge[ED]],
+      partitionStrategy: PartitionStrategy): RDD[(Pid, EdgePartition[ED])] = {
       // Get the number of partitions
       val numPartitions = edges.partitions.size
-      val ceilSqrt: Pid = math.ceil(math.sqrt(numPartitions)).toInt
+
     edges.map { e =>
       // Random partitioning based on the source vertex id.
       // val part: Pid = edgePartitionFunction1D(e.srcId, e.dstId, numPartitions)
       // val part: Pid = edgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
-      val part: Pid = randomVertexCut(e.srcId, e.dstId, numPartitions)
+      //val part: Pid = randomVertexCut(e.srcId, e.dstId, numPartitions)
+      val part: Pid = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
 
       // Should we be using 3-tuple or an optimized class
       new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
@@ -538,7 +567,8 @@ object GraphImpl {
    *
    */
   protected def edgePartitionFunction2D(src: Vid, dst: Vid,
-    numParts: Pid, ceilSqrtNumParts: Pid): Pid = {
+    numParts: Pid): Pid = {
+    val ceilSqrtNumParts: Pid = math.ceil(math.sqrt(numParts)).toInt
     val mixingPrime: Vid = 1125899906842597L
     val col: Pid = ((math.abs(src) * mixingPrime) % ceilSqrtNumParts).toInt
     val row: Pid = ((math.abs(dst) * mixingPrime) % ceilSqrtNumParts).toInt