[SPARK-2790] [PySpark] fix zip with serializers which have different batch sizes.

If two RDDs use serializers with different batch sizes, zip() now re-serializes the one with the smaller batch size to match the larger one, and then calls RDD.zip() in Spark (the JVM).
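For illustration, a minimal standalone sketch of what this enables, mirroring the new test case further down (the RDD contents, batch size, and app name are made up for the example, not taken from the patch):

from pyspark import SparkContext
from pyspark.serializers import BatchedSerializer, MarshalSerializer, PickleSerializer

sc = SparkContext("local", "zip-batch-size-demo")

# Two RDDs with the same number of partitions and items, but deliberately
# re-serialized with different serializers and batch sizes.
a = sc.parallelize(range(5))._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = sc.parallelize(range(100, 105))._reserialize(MarshalSerializer())

# With this fix, zip() re-batches one side so the serialized streams line up.
print(a.zip(b).collect())   # [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]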

Author: Davies Liu <davies.liu@gmail.com>

Closes #1894 from davies/zip and squashes the following commits:

c4652ea [Davies Liu] add more test cases
6d05fc8 [Davies Liu] Merge branch 'master' into zip
813b1e4 [Davies Liu] add more tests for failed cases
a4aafda [Davies Liu] fix zip with serializers which have different batch sizes.
Authored by Davies Liu on 2014-08-19 14:46:32 -07:00, committed by Josh Rosen
commit d7e80c2597 (parent 76eaeb4523)
3 changed files with 54 additions and 1 deletion

python/pyspark/rdd.py

@@ -1687,6 +1687,31 @@ class RDD(object):
        >>> x.zip(y).collect()
        [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
        """
        if self.getNumPartitions() != other.getNumPartitions():
            raise ValueError("Can only zip with RDD which has the same number of partitions")

        def get_batch_size(ser):
            if isinstance(ser, BatchedSerializer):
                return ser.batchSize
            return 0

        def batch_as(rdd, batchSize):
            ser = rdd._jrdd_deserializer
            if isinstance(ser, BatchedSerializer):
                ser = ser.serializer
            return rdd._reserialize(BatchedSerializer(ser, batchSize))

        my_batch = get_batch_size(self._jrdd_deserializer)
        other_batch = get_batch_size(other._jrdd_deserializer)
        if my_batch != other_batch:
            # use the greatest batchSize to batch the other one.
            if my_batch > other_batch:
                other = batch_as(other, my_batch)
            else:
                self = batch_as(self, other_batch)

        # There will be an Exception in JVM if there are different number
        # of items in each partitions.
        pairRDD = self._jrdd.zip(other._jrdd)
        deserializer = PairDeserializer(self._jrdd_deserializer,
                                        other._jrdd_deserializer)
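To see why the batch sizes have to agree, here is a small pure-Python illustration with no Spark involved (the batched helper below is hypothetical, a stand-in for how BatchedSerializer groups items into blocks): the JVM-side zip pairs serialized block i of one stream with block i of the other, so blocks of different sizes put the items out of alignment.

from itertools import islice

def batched(items, batch_size):
    # Group an iterable into lists of at most batch_size items,
    # roughly how BatchedSerializer writes pickled blocks.
    it = iter(items)
    while True:
        block = list(islice(it, batch_size))
        if not block:
            return
        yield block

left = list(batched(range(5), 2))           # [[0, 1], [2, 3], [4]]
right = list(batched(range(100, 105), 3))   # [[100, 101, 102], [103, 104]]

# Pairing block-by-block (what zipping the serialized streams amounts to)
# yields blocks of unequal length; PairDeserializer then sees
# len(keys) != len(vals) and refuses to pair them up.
for keys, vals in zip(left, right):
    print((len(keys), len(vals)))   # (2, 3), then (2, 2)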

python/pyspark/serializers.py

@@ -255,6 +255,9 @@ class PairDeserializer(CartesianDeserializer):
    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            if len(keys) != len(vals):
                raise ValueError("Can not deserialize RDD with different number of items"
                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair
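For reference, the added length check can be exercised on its own; the toy load_pairs function below just mimics the new PairDeserializer behavior and is not part of the patch:

from itertools import izip   # Python 2, matching pyspark/serializers.py at the time

def load_pairs(keys, vals):
    # Mimic the new check: refuse to pair up key/value batches of
    # different lengths instead of silently truncating to the shorter one.
    if len(keys) != len(vals):
        raise ValueError("Can not deserialize RDD with different number of items"
                         " in pair: (%d, %d)" % (len(keys), len(vals)))
    return list(izip(keys, vals))

print(load_pairs([0, 1, 2], [100, 101, 102]))   # [(0, 100), (1, 101), (2, 102)]
# load_pairs([0, 1], [100, 101, 102]) would raise ValueError: ... in pair: (2, 3)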

python/pyspark/tests.py

@@ -39,7 +39,7 @@ else:
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
@@ -339,6 +339,31 @@ class TestRDDFunctions(PySparkTestCase):
        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
        self.assertEquals(N, m)

    def test_zip_with_different_serializers(self):
        a = self.sc.parallelize(range(5))
        b = self.sc.parallelize(range(100, 105))
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
        b = b._reserialize(MarshalSerializer())
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])

    def test_zip_with_different_number_of_items(self):
        a = self.sc.parallelize(range(5), 2)
        # different number of partitions
        b = self.sc.parallelize(range(100, 106), 3)
        self.assertRaises(ValueError, lambda: a.zip(b))
        # different number of batched items in JVM
        b = self.sc.parallelize(range(100, 104), 2)
        self.assertRaises(Exception, lambda: a.zip(b).count())
        # different number of items in one pair
        b = self.sc.parallelize(range(100, 106), 2)
        self.assertRaises(Exception, lambda: a.zip(b).count())
        # same total number of items, but different distributions
        a = self.sc.parallelize([2, 3], 2).flatMap(range)
        b = self.sc.parallelize([3, 2], 2).flatMap(range)
        self.assertEquals(a.count(), b.count())
        self.assertRaises(Exception, lambda: a.zip(b).count())
class TestIO(PySparkTestCase):