Merge pull request #88 from amplab/varenc
Fixed a bug that variable encoding doesn't work for ints that use all 64 bits.
This commit is contained in:
commit
1c8500efc0
|
@ -167,7 +167,7 @@ class DoubleAggMsgSerializer extends Serializer {
|
||||||
// Helper classes to shorten the implementation of those special serializers.
|
// Helper classes to shorten the implementation of those special serializers.
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
|
abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
|
||||||
// The implementation should override this one.
|
// The implementation should override this one.
|
||||||
def writeObject[T](t: T): SerializationStream
|
def writeObject[T](t: T): SerializationStream
|
||||||
|
|
||||||
|
@ -280,7 +280,7 @@ sealed abstract class ShuffleSerializationStream(s: OutputStream) extends Serial
|
||||||
override def close(): Unit = s.close()
|
override def close(): Unit = s.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
|
abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
|
||||||
// The implementation should override this one.
|
// The implementation should override this one.
|
||||||
def readObject[T](): T
|
def readObject[T](): T
|
||||||
|
|
||||||
|
@ -311,17 +311,16 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
|
||||||
def readVarLong(optimizePositive: Boolean): Long = {
|
def readVarLong(optimizePositive: Boolean): Long = {
|
||||||
// TODO: unroll the while loop.
|
// TODO: unroll the while loop.
|
||||||
var value: Long = 0L
|
var value: Long = 0L
|
||||||
var i: Int = 0
|
|
||||||
def readOrThrow(): Int = {
|
def readOrThrow(): Int = {
|
||||||
val in = s.read()
|
val in = s.read()
|
||||||
if (in < 0) throw new java.io.EOFException
|
if (in < 0) throw new java.io.EOFException
|
||||||
in & 0xFF
|
in & 0xFF
|
||||||
}
|
}
|
||||||
|
var i: Int = 0
|
||||||
var b: Int = readOrThrow()
|
var b: Int = readOrThrow()
|
||||||
while ((b & 0x80) != 0) {
|
while (i < 56 && (b & 0x80) != 0) {
|
||||||
value |= (b & 0x7F).toLong << i
|
value |= (b & 0x7F).toLong << i
|
||||||
i += 7
|
i += 7
|
||||||
if (i > 63) throw new IllegalArgumentException("Variable length quantity is too long")
|
|
||||||
b = readOrThrow()
|
b = readOrThrow()
|
||||||
}
|
}
|
||||||
val ret = value | (b.toLong << i)
|
val ret = value | (b.toLong << i)
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
package org.apache.spark.graph
|
package org.apache.spark.graph
|
||||||
|
|
||||||
|
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
|
||||||
|
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
import org.scalatest.FunSuite
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark._
|
||||||
import org.apache.spark.graph.LocalSparkContext._
|
import org.apache.spark.graph.LocalSparkContext._
|
||||||
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
|
|
||||||
import org.apache.spark.graph.impl._
|
import org.apache.spark.graph.impl._
|
||||||
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||||
import org.apache.spark._
|
import org.apache.spark.serializer.SerializationStream
|
||||||
|
|
||||||
|
|
||||||
class SerializerSuite extends FunSuite with LocalSparkContext {
|
class SerializerSuite extends FunSuite with LocalSparkContext {
|
||||||
|
@ -143,4 +146,36 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
||||||
bmsgs.partitionBy(new HashPartitioner(3)).collect()
|
bmsgs.partitionBy(new HashPartitioner(3)).collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("variable long encoding") {
|
||||||
|
def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
|
||||||
|
val bout = new ByteArrayOutputStream
|
||||||
|
val stream = new ShuffleSerializationStream(bout) {
|
||||||
|
def writeObject[T](t: T): SerializationStream = {
|
||||||
|
writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
|
||||||
|
this
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stream.writeObject(v)
|
||||||
|
|
||||||
|
val bin = new ByteArrayInputStream(bout.toByteArray)
|
||||||
|
val dstream = new ShuffleDeserializationStream(bin) {
|
||||||
|
def readObject[T](): T = {
|
||||||
|
readVarLong(optimizePositive).asInstanceOf[T]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val read = dstream.readObject[Long]()
|
||||||
|
assert(read === v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference)
|
||||||
|
val d = Random.nextLong() % 128
|
||||||
|
Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d,
|
||||||
|
1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number =>
|
||||||
|
testVarLongEncoding(number, optimizePositive = false)
|
||||||
|
testVarLongEncoding(number, optimizePositive = true)
|
||||||
|
testVarLongEncoding(-number, optimizePositive = false)
|
||||||
|
testVarLongEncoding(-number, optimizePositive = true)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue