[SPARK-4841] fix zip with textFile()

UTF8Deserializer cannot be used inside a BatchedSerializer, so always use PickleSerializer() when changing the batchSize in zip().

Also, if two RDDs already have the same batch size, they do not need to be re-serialized.

Author: Davies Liu <davies@databricks.com>

Closes #3706 from davies/fix_4841 and squashes the following commits:

20ce3a3 [Davies Liu] fix bug in _reserialize()
e3ebf7c [Davies Liu] add comment
379d2c8 [Davies Liu] fix zip with textFile()
This commit is contained in:
Davies Liu 2014-12-15 22:58:26 -08:00 committed by Josh Rosen
parent c7628771da
commit c246b95dd2
3 changed files with 26 additions and 14 deletions

View file

@ -469,7 +469,6 @@ class RDD(object):
def _reserialize(self, serializer=None):
    """
    Return an RDD whose partitions are deserialized with *serializer*
    (defaults to the SparkContext's serializer).

    Bug fix: the previous version skipped the identity ``map`` for
    PipelinedRDDs and mutated ``_jrdd_deserializer`` in place, leaving the
    already-built pipelined command paired with the wrong serializer.
    Always inserting an identity map makes the re-serialization take
    effect for every RDD type.
    """
    serializer = serializer or self.ctx.serializer
    if self._jrdd_deserializer != serializer:
        # Identity map forces the data through the new serializer while
        # keeping the existing partitioning intact.
        self = self.map(lambda x: x, preservesPartitioning=True)
        self._jrdd_deserializer = serializer
    return self
@ -1798,16 +1797,14 @@ class RDD(object):
def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
return 1
return 1 # not batched
def batch_as(rdd, batchSize):
ser = rdd._jrdd_deserializer
if isinstance(ser, BatchedSerializer):
ser = ser.serializer
return rdd._reserialize(BatchedSerializer(ser, batchSize))
return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
if my_batch != other_batch:
# use the smallest batchSize for both of them
batchSize = min(my_batch, other_batch)
if batchSize <= 0:

View file

@ -463,6 +463,9 @@ class CompressedSerializer(FramedSerializer):
def loads(self, obj):
    """Inflate the zlib-compressed payload, then delegate to the wrapped serializer."""
    raw = zlib.decompress(obj)
    return self.serializer.loads(raw)
def __eq__(self, other):
    # Two CompressedSerializers are equal iff they wrap equal inner serializers.
    if not isinstance(other, CompressedSerializer):
        return False
    return self.serializer == other.serializer
class UTF8Deserializer(Serializer):
@ -489,6 +492,9 @@ class UTF8Deserializer(Serializer):
except EOFError:
return
def __eq__(self, other):
    # Equality requires the same class and the same unicode-decoding mode.
    if not isinstance(other, UTF8Deserializer):
        return False
    return self.use_unicode == other.use_unicode
def read_long(stream):
length = stream.read(8)

View file

@ -533,6 +533,15 @@ class RDDTests(ReusedPySparkTestCase):
a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
# regression test for SPARK-4841
path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
t = self.sc.textFile(path)
cnt = t.count()
self.assertEqual(cnt, t.zip(t).count())
rdd = t.map(str)
self.assertEqual(cnt, t.zip(rdd).count())
# regression test for bug in _reserializer()
self.assertEqual(cnt, t.zip(rdd).count())
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)