Replace RoutingTableMessage with pair
RoutingTableMessage was used to construct routing tables to enable joining VertexRDDs with partitioned edges. It stored three elements: the destination vertex ID, the source edge partition, and a byte specifying the position in which the edge partition referenced the vertex to enable join elimination. However, this was incompatible with sort-based shuffle (SPARK-2045). It was also slightly wasteful, because partition IDs are usually much smaller than 2^32, though this was mitigated by a custom serializer that used variable-length encoding. This commit replaces RoutingTableMessage with a pair of (VertexId, Int) where the Int encodes both the source partition ID (in the lower 30 bits) and the position (in the top 2 bits). Author: Ankur Dave <ankurdave@gmail.com> Closes #1553 from ankurdave/remove-RoutingTableMessage and squashes the following commits: 697e17b [Ankur Dave] Replace RoutingTableMessage with pair
This commit is contained in:
parent
60f0ae3d87
commit
2d25e34814
|
@ -35,7 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
|
|||
|
||||
def registerClasses(kryo: Kryo) {
|
||||
kryo.register(classOf[Edge[Object]])
|
||||
kryo.register(classOf[RoutingTableMessage])
|
||||
kryo.register(classOf[(VertexId, Object)])
|
||||
kryo.register(classOf[EdgePartition[Object, Object]])
|
||||
kryo.register(classOf[BitSet])
|
||||
|
|
|
@ -27,26 +27,13 @@ import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
|
|||
import org.apache.spark.graphx._
|
||||
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
|
||||
|
||||
/**
|
||||
* A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
|
||||
* the edge partition references `vid` in the specified `position` (src, dst, or both).
|
||||
*/
|
||||
private[graphx]
|
||||
class RoutingTableMessage(
|
||||
var vid: VertexId,
|
||||
var pid: PartitionID,
|
||||
var position: Byte)
|
||||
extends Product2[VertexId, (PartitionID, Byte)] with Serializable {
|
||||
override def _1 = vid
|
||||
override def _2 = (pid, position)
|
||||
override def canEqual(that: Any): Boolean = that.isInstanceOf[RoutingTableMessage]
|
||||
}
|
||||
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
|
||||
|
||||
private[graphx]
|
||||
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
|
||||
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
|
||||
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
|
||||
new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
|
||||
new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
|
||||
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
|
||||
}
|
||||
}
|
||||
|
@ -62,6 +49,23 @@ object RoutingTableMessageRDDFunctions {
|
|||
|
||||
private[graphx]
|
||||
object RoutingTablePartition {
|
||||
/**
|
||||
* A message from an edge partition to a vertex specifying the position in which the edge
|
||||
* partition references the vertex (src, dst, or both). The edge partition is encoded in the lower
|
||||
* 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int.
|
||||
*/
|
||||
type RoutingTableMessage = (VertexId, Int)
|
||||
|
||||
/**
 * Builds a routing table message for `vid`. The second element of the pair packs `position`
 * into the top 2 bits of an Int and `pid` into the remaining low 30 bits; any higher bits of
 * `pid` are masked off.
 */
private def toMessage(vid: VertexId, pid: PartitionID, position: Byte): RoutingTableMessage = {
  val packed = (position << 30) | (pid & 0x3FFFFFFF)
  (vid, packed)
}
|
||||
|
||||
/** Returns the vertex ID, which is the first element of the message pair. */
private def vidFromMessage(msg: RoutingTableMessage): VertexId = {
  val (vid, _) = msg
  vid
}
|
||||
/** Returns the edge partition ID held in the low 30 bits of the message's Int. */
private def pidFromMessage(msg: RoutingTableMessage): PartitionID = {
  val lower30Mask = 0x3FFFFFFF
  msg._2 & lower30Mask
}
|
||||
/**
 * Returns the position held in the top 2 bits of the message's Int.
 *
 * The unsigned shift (`>>>`) is deliberate: when the top bit is set the packed Int is negative,
 * and an arithmetic shift (`>>`) would sign-extend it, yielding -2 or -1 instead of the 0x2 or
 * 0x3 that was originally encoded. The low two bits of the result are the same either way.
 */
private def positionFromMessage(msg: RoutingTableMessage): Byte = (msg._2 >>> 30).toByte
|
||||
|
||||
/** A `RoutingTablePartition` constructed from an empty array. */
val empty: RoutingTablePartition = new RoutingTablePartition(Array.empty)
|
||||
|
||||
/** Generate a `RoutingTableMessage` for each vertex referenced in `edgePartition`. */
|
||||
|
@ -77,7 +81,9 @@ object RoutingTablePartition {
|
|||
map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
|
||||
}
|
||||
map.iterator.map { vidAndPosition =>
|
||||
new RoutingTableMessage(vidAndPosition._1, pid, vidAndPosition._2)
|
||||
val vid = vidAndPosition._1
|
||||
val position = vidAndPosition._2
|
||||
toMessage(vid, pid, position)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,9 +94,12 @@ object RoutingTablePartition {
|
|||
val srcFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
|
||||
val dstFlags = Array.fill(numEdgePartitions)(new PrimitiveVector[Boolean])
|
||||
for (msg <- iter) {
|
||||
pid2vid(msg.pid) += msg.vid
|
||||
srcFlags(msg.pid) += (msg.position & 0x1) != 0
|
||||
dstFlags(msg.pid) += (msg.position & 0x2) != 0
|
||||
val vid = vidFromMessage(msg)
|
||||
val pid = pidFromMessage(msg)
|
||||
val position = positionFromMessage(msg)
|
||||
pid2vid(pid) += vid
|
||||
srcFlags(pid) += (position & 0x1) != 0
|
||||
dstFlags(pid) += (position & 0x2) != 0
|
||||
}
|
||||
|
||||
new RoutingTablePartition(pid2vid.zipWithIndex.map {
|
||||
|
|
|
@ -24,9 +24,11 @@ import java.nio.ByteBuffer
|
|||
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.graphx._
|
||||
import org.apache.spark.serializer._
|
||||
|
||||
import org.apache.spark.graphx._
|
||||
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
|
||||
|
||||
private[graphx]
|
||||
class RoutingTableMessageSerializer extends Serializer with Serializable {
|
||||
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
|
||||
|
@ -35,10 +37,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
|
|||
new ShuffleSerializationStream(s) {
|
||||
def writeObject[T: ClassTag](t: T): SerializationStream = {
|
||||
val msg = t.asInstanceOf[RoutingTableMessage]
|
||||
writeVarLong(msg.vid, optimizePositive = false)
|
||||
writeUnsignedVarInt(msg.pid)
|
||||
// TODO: Write only the bottom two bits of msg.position
|
||||
s.write(msg.position)
|
||||
writeVarLong(msg._1, optimizePositive = false)
|
||||
writeInt(msg._2)
|
||||
this
|
||||
}
|
||||
}
|
||||
|
@ -47,10 +47,8 @@ class RoutingTableMessageSerializer extends Serializer with Serializable {
|
|||
new ShuffleDeserializationStream(s) {
|
||||
override def readObject[T: ClassTag](): T = {
|
||||
val a = readVarLong(optimizePositive = false)
|
||||
val b = readUnsignedVarInt()
|
||||
val c = s.read()
|
||||
if (c == -1) throw new EOFException
|
||||
new RoutingTableMessage(a, b, c.toByte).asInstanceOf[T]
|
||||
val b = readInt()
|
||||
(a, b).asInstanceOf[T]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ package object graphx {
|
|||
*/
|
||||
type VertexId = Long
|
||||
|
||||
/** Integer identifier of a graph partition. */
|
||||
/** Integer identifier of a graph partition. Must be less than 2^30. */
|
||||
// TODO: Consider using Char.
|
||||
type PartitionID = Int
|
||||
|
||||
|
|
Loading…
Reference in a new issue