From 0e813cd483eb4cc612404f8602e635b29295efc3 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Sat, 9 Nov 2013 23:29:37 -0800
Subject: [PATCH] Fix the hanging bug.

---
 .../apache/spark/graph/impl/Serializers.scala | 12 ++++--
 .../apache/spark/graph/SerializerSuite.scala  | 41 ++++++++++++++-----
 2 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
index 54fd65e738..c56bbc8aee 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -1,9 +1,9 @@
 package org.apache.spark.graph.impl
 
-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
 import java.nio.ByteBuffer
 
-import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
+import org.apache.spark.serializer._
 
 
 /** A special shuffle serializer for VertexBroadcastMessage[Int]. */
@@ -185,11 +185,15 @@ sealed abstract class ShuffleDeserializationStream(s: InputStream) extends Deser
   def readObject[T](): T
 
   def readInt(): Int = {
-    (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+    val first = s.read()
+    if (first < 0) throw new EOFException
+    (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
   }
 
   def readLong(): Long = {
-    (s.read().toLong << 56) |
+    val first = s.read()
+    if (first < 0) throw new EOFException()
+    (first.toLong << 56) |
     (s.read() & 0xFF).toLong << 48 |
     (s.read() & 0xFF).toLong << 40 |
     (s.read() & 0xFF).toLong << 32 |
diff --git a/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
index 5a59fd912a..0d55cc0189 100644
--- a/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
@@ -4,8 +4,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext
 import org.apache.spark.graph.LocalSparkContext._
-import java.io.ByteArrayInputStream
-import java.io.ByteArrayOutputStream
+import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.graph.impl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
 import org.apache.spark._
@@ -31,6 +30,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestVertexBroadcastMessageLong") {
@@ -48,6 +51,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestVertexBroadcastMessageDouble") {
@@ -65,6 +72,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestAggregationMessageInt") {
@@ -82,6 +93,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestAggregationMessageLong") {
@@ -99,6 +114,10 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestAggregationMessageDouble") {
@@ -116,23 +135,25 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     assert(outMsg.vid === inMsg2.vid)
     assert(outMsg.data === inMsg1.data)
     assert(outMsg.data === inMsg2.data)
+
+    intercept[EOFException] {
+      inStrm.readObject()
+    }
   }
 
   test("TestShuffleVertexBroadcastMsg") {
     withSpark(new SparkContext("local[2]", "test")) { sc =>
-      val bmsgs = sc.parallelize(
-        (0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
-      val partitioner = new HashPartitioner(3)
-      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+      val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
+        new VertexBroadcastMsg[Int](pid, pid, pid)
+      }
+      bmsgs.partitionBy(new HashPartitioner(3)).collect()
     }
   }
 
   test("TestShuffleAggregationMsg") {
     withSpark(new SparkContext("local[2]", "test")) { sc =>
-      val bmsgs = sc.parallelize(
-        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
-      val partitioner = new HashPartitioner(3)
-      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+      val bmsgs = sc.parallelize(0 until 100, 10).map(pid => new AggregationMsg[Int](pid, pid))
+      bmsgs.partitionBy(new HashPartitioner(3)).collect()
     }
   }
 }
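For context on the hang this patch fixes: java.io.InputStream.read() returns
-1 once the stream is exhausted, but the old readInt()/readLong() masked every
byte with 0xFF before checking anything, and (-1 & 0xFF) == 255. An exhausted
deserialization stream therefore kept assembling fabricated values instead of
signaling end-of-stream, so a consumer reading messages until EOF never
terminated. A minimal standalone Scala sketch of the failure mode
(EofMaskingDemo is illustrative only, not code from this repo):

    import java.io.{ByteArrayInputStream, EOFException}

    object EofMaskingDemo extends App {
      val s = new ByteArrayInputStream(Array.empty[Byte]) // already at EOF

      // Unguarded read, as in the old readInt(): every s.read() returns -1,
      // and (-1 & 0xFF) == 255, so the four masked reads assemble
      // 0xFFFFFFFF (= -1) without ever signaling end of stream.
      val bogus = (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 |
        (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
      println(bogus) // prints -1, indistinguishable from real data

      // Guarded read, as in the fixed readInt(): check the first byte
      // before masking so exhaustion surfaces as an EOFException.
      try {
        val first = s.read()
        if (first < 0) throw new EOFException
      } catch {
        case _: EOFException => println("EOF detected") // caller can stop here
      }
    }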