Fix the hanging bug.
This commit is contained in:
parent
f6c946206a
commit
0e813cd483
|
@ -1,9 +1,9 @@
|
|||
package org.apache.spark.graph.impl
|
||||
|
||||
import java.io.{InputStream, OutputStream}
|
||||
import java.io.{EOFException, InputStream, OutputStream}
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
|
||||
import org.apache.spark.serializer._
|
||||
|
||||
|
||||
/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
|
||||
|
@ -185,11 +185,15 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
|
|||
def readObject[T](): T
|
||||
|
||||
def readInt(): Int = {
|
||||
(s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
|
||||
val first = s.read()
|
||||
if (first < 0) throw new EOFException
|
||||
(first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
|
||||
}
|
||||
|
||||
def readLong(): Long = {
|
||||
(s.read().toLong << 56) |
|
||||
val first = s.read()
|
||||
if (first < 0) throw new EOFException()
|
||||
(first.toLong << 56) |
|
||||
(s.read() & 0xFF).toLong << 48 |
|
||||
(s.read() & 0xFF).toLong << 40 |
|
||||
(s.read() & 0xFF).toLong << 32 |
|
||||
|
|
|
@ -4,8 +4,7 @@ import org.scalatest.FunSuite
|
|||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.graph.LocalSparkContext._
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
|
||||
import org.apache.spark.graph.impl._
|
||||
import org.apache.spark.graph.impl.MsgRDDFunctions._
|
||||
import org.apache.spark._
|
||||
|
@ -31,6 +30,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestVertexBroadcastMessageLong") {
|
||||
|
@ -48,6 +51,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestVertexBroadcastMessageDouble") {
|
||||
|
@ -65,6 +72,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestAggregationMessageInt") {
|
||||
|
@ -82,6 +93,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestAggregationMessageLong") {
|
||||
|
@ -99,6 +114,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestAggregationMessageDouble") {
|
||||
|
@ -116,23 +135,25 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
|
|||
assert(outMsg.vid === inMsg2.vid)
|
||||
assert(outMsg.data === inMsg1.data)
|
||||
assert(outMsg.data === inMsg2.data)
|
||||
|
||||
intercept[EOFException] {
|
||||
inStrm.readObject()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestShuffleVertexBroadcastMsg") {
|
||||
withSpark(new SparkContext("local[2]", "test")) { sc =>
|
||||
val bmsgs = sc.parallelize(
|
||||
(0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
|
||||
val partitioner = new HashPartitioner(3)
|
||||
val bmsgsArray = bmsgs.partitionBy(partitioner).collect
|
||||
val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
|
||||
new VertexBroadcastMsg[Int](pid, pid, pid)
|
||||
}
|
||||
bmsgs.partitionBy(new HashPartitioner(3)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
test("TestShuffleAggregationMsg") {
|
||||
withSpark(new SparkContext("local[2]", "test")) { sc =>
|
||||
val bmsgs = sc.parallelize(
|
||||
(0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
|
||||
val partitioner = new HashPartitioner(3)
|
||||
val bmsgsArray = bmsgs.partitionBy(partitioner).collect
|
||||
val bmsgs = sc.parallelize(0 until 100, 10).map(pid => new AggregationMsg[Int](pid, pid))
|
||||
bmsgs.partitionBy(new HashPartitioner(3)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue