Fix the hanging bug.

This commit is contained in:
Reynold Xin 2013-11-09 23:29:37 -08:00
parent f6c946206a
commit 0e813cd483
2 changed files with 39 additions and 14 deletions

View file

@ -1,9 +1,9 @@
package org.apache.spark.graph.impl package org.apache.spark.graph.impl
import java.io.{InputStream, OutputStream} import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer import java.nio.ByteBuffer
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer} import org.apache.spark.serializer._
/** A special shuffle serializer for VertexBroadcastMessage[Int]. */ /** A special shuffle serializer for VertexBroadcastMessage[Int]. */
@ -185,11 +185,15 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
def readObject[T](): T def readObject[T](): T
def readInt(): Int = { def readInt(): Int = {
(s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) val first = s.read()
if (first < 0) throw new EOFException
(first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
} }
def readLong(): Long = { def readLong(): Long = {
(s.read().toLong << 56) | val first = s.read()
if (first < 0) throw new EOFException()
(first.toLong << 56) |
(s.read() & 0xFF).toLong << 48 | (s.read() & 0xFF).toLong << 48 |
(s.read() & 0xFF).toLong << 40 | (s.read() & 0xFF).toLong << 40 |
(s.read() & 0xFF).toLong << 32 | (s.read() & 0xFF).toLong << 32 |

View file

@ -4,8 +4,7 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkContext import org.apache.spark.SparkContext
import org.apache.spark.graph.LocalSparkContext._ import org.apache.spark.graph.LocalSparkContext._
import java.io.ByteArrayInputStream import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
import java.io.ByteArrayOutputStream
import org.apache.spark.graph.impl._ import org.apache.spark.graph.impl._
import org.apache.spark.graph.impl.MsgRDDFunctions._ import org.apache.spark.graph.impl.MsgRDDFunctions._
import org.apache.spark._ import org.apache.spark._
@ -31,6 +30,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestVertexBroadcastMessageLong") { test("TestVertexBroadcastMessageLong") {
@ -48,6 +51,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestVertexBroadcastMessageDouble") { test("TestVertexBroadcastMessageDouble") {
@ -65,6 +72,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestAggregationMessageInt") { test("TestAggregationMessageInt") {
@ -82,6 +93,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestAggregationMessageLong") { test("TestAggregationMessageLong") {
@ -99,6 +114,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestAggregationMessageDouble") { test("TestAggregationMessageDouble") {
@ -116,23 +135,25 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
assert(outMsg.vid === inMsg2.vid) assert(outMsg.vid === inMsg2.vid)
assert(outMsg.data === inMsg1.data) assert(outMsg.data === inMsg1.data)
assert(outMsg.data === inMsg2.data) assert(outMsg.data === inMsg2.data)
intercept[EOFException] {
inStrm.readObject()
}
} }
test("TestShuffleVertexBroadcastMsg") { test("TestShuffleVertexBroadcastMsg") {
withSpark(new SparkContext("local[2]", "test")) { sc => withSpark(new SparkContext("local[2]", "test")) { sc =>
val bmsgs = sc.parallelize( val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
(0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10) new VertexBroadcastMsg[Int](pid, pid, pid)
val partitioner = new HashPartitioner(3) }
val bmsgsArray = bmsgs.partitionBy(partitioner).collect bmsgs.partitionBy(new HashPartitioner(3)).collect()
} }
} }
test("TestShuffleAggregationMsg") { test("TestShuffleAggregationMsg") {
withSpark(new SparkContext("local[2]", "test")) { sc => withSpark(new SparkContext("local[2]", "test")) { sc =>
val bmsgs = sc.parallelize( val bmsgs = sc.parallelize(0 until 100, 10).map(pid => new AggregationMsg[Int](pid, pid))
(0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10) bmsgs.partitionBy(new HashPartitioner(3)).collect()
val partitioner = new HashPartitioner(3)
val bmsgsArray = bmsgs.partitionBy(partitioner).collect
} }
} }