Merge pull request #88 from amplab/varenc

Fixed a bug where variable-length encoding did not work for integers that use all 64 bits.
This commit is contained in:
Ankur Dave 2013-12-05 16:25:44 -08:00
commit 1c8500efc0
2 changed files with 42 additions and 8 deletions

View file

@ -167,7 +167,7 @@ class DoubleAggMsgSerializer extends Serializer {
// Helper classes to shorten the implementation of those special serializers.
////////////////////////////////////////////////////////////////////////////////
sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
// The implementation should override this one.
def writeObject[T](t: T): SerializationStream
@ -280,7 +280,7 @@ sealed abstract class ShuffleSerializationStream(s: OutputStream) extends Serial
override def close(): Unit = s.close()
}
sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
// The implementation should override this one.
def readObject[T](): T
@ -311,17 +311,16 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
def readVarLong(optimizePositive: Boolean): Long = {
// TODO: unroll the while loop.
var value: Long = 0L
var i: Int = 0
def readOrThrow(): Int = {
val in = s.read()
if (in < 0) throw new java.io.EOFException
in & 0xFF
}
var i: Int = 0
var b: Int = readOrThrow()
while ((b & 0x80) != 0) {
while (i < 56 && (b & 0x80) != 0) {
value |= (b & 0x7F).toLong << i
i += 7
if (i > 63) throw new IllegalArgumentException("Variable length quantity is too long")
b = readOrThrow()
}
val ret = value | (b.toLong << i)

View file

@ -1,13 +1,16 @@
package org.apache.spark.graph
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
import scala.util.Random
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark._
import org.apache.spark.graph.LocalSparkContext._
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.graph.impl._
import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark._
import org.apache.spark.serializer.SerializationStream
class SerializerSuite extends FunSuite with LocalSparkContext {
@ -143,4 +146,36 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
bmsgs.partitionBy(new HashPartitioner(3)).collect()
}
}
test("variable long encoding") {
  // Writes `v` with writeVarLong, reads it back with readVarLong using the same
  // `optimizePositive` setting, and asserts that the round trip is lossless.
  def testVarLongEncoding(v: Long, optimizePositive: Boolean): Unit = {
    val bout = new ByteArrayOutputStream
    val stream = new ShuffleSerializationStream(bout) {
      def writeObject[T](t: T): SerializationStream = {
        writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
        this
      }
    }
    stream.writeObject(v)

    val bin = new ByteArrayInputStream(bout.toByteArray)
    val dstream = new ShuffleDeserializationStream(bin) {
      def readObject[T](): T = {
        readVarLong(optimizePositive).asInstanceOf[T]
      }
    }
    val read = dstream.readObject[Long]()
    assert(read === v)
  }

  // Test all variable-length encoding code paths: each additional encoded byte carries
  // 7 more bits, so probe values around every 7-bit boundary (1L << 0, 1L << 7, ..., 1L << 63).
  //
  // The shifts must be parenthesized: in Scala `+` binds tighter than `<<`, so the
  // unparenthesized `1L << 7 + d` means `1L << (7 + d)`, which yields an arbitrary power of
  // two instead of a value near the boundary.
  //
  // `d` is a random offset in (-128, 128). Long arithmetic wraps silently near the top
  // boundary, which is harmless for a round-trip check.
  val d = Random.nextLong() % 128
  val boundaryValues = (0 to 63 by 7).map(shift => (1L << shift) + d)
  (0L +: boundaryValues).foreach { number =>
    testVarLongEncoding(number, optimizePositive = false)
    testVarLongEncoding(number, optimizePositive = true)
    testVarLongEncoding(-number, optimizePositive = false)
    testVarLongEncoding(-number, optimizePositive = true)
  }
}
}