[SPARK-2790] [PySpark] fix zip with serializers which have different batch sizes.
If two RDDs have different batch sizes in their serializers, then it will try to re-serialize the one with the smaller batch size, then call RDD.zip() in Spark. Author: Davies Liu <davies.liu@gmail.com> Closes #1894 from davies/zip and squashes the following commits: c4652ea [Davies Liu] add more test cases 6d05fc8 [Davies Liu] Merge branch 'master' into zip 813b1e4 [Davies Liu] add more tests for failed cases a4aafda [Davies Liu] fix zip with serializers which have different batch sizes.
This commit is contained in:
parent
76eaeb4523
commit
d7e80c2597
|
@ -1687,6 +1687,31 @@ class RDD(object):
|
|||
>>> x.zip(y).collect()
|
||||
[(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
|
||||
"""
|
||||
if self.getNumPartitions() != other.getNumPartitions():
|
||||
raise ValueError("Can only zip with RDD which has the same number of partitions")
|
||||
|
||||
def get_batch_size(ser):
|
||||
if isinstance(ser, BatchedSerializer):
|
||||
return ser.batchSize
|
||||
return 0
|
||||
|
||||
def batch_as(rdd, batchSize):
|
||||
ser = rdd._jrdd_deserializer
|
||||
if isinstance(ser, BatchedSerializer):
|
||||
ser = ser.serializer
|
||||
return rdd._reserialize(BatchedSerializer(ser, batchSize))
|
||||
|
||||
my_batch = get_batch_size(self._jrdd_deserializer)
|
||||
other_batch = get_batch_size(other._jrdd_deserializer)
|
||||
if my_batch != other_batch:
|
||||
# use the greatest batchSize to batch the other one.
|
||||
if my_batch > other_batch:
|
||||
other = batch_as(other, my_batch)
|
||||
else:
|
||||
self = batch_as(self, other_batch)
|
||||
|
||||
# There will be an Exception in JVM if there are different number
|
||||
# of items in each partitions.
|
||||
pairRDD = self._jrdd.zip(other._jrdd)
|
||||
deserializer = PairDeserializer(self._jrdd_deserializer,
|
||||
other._jrdd_deserializer)
|
||||
|
|
|
@ -255,6 +255,9 @@ class PairDeserializer(CartesianDeserializer):
|
|||
|
||||
def load_stream(self, stream):
    """Yield (key, value) pairs deserialized from *stream*.

    Each batch read from the stream must contain exactly as many keys as
    values; otherwise the zipped RDDs were misaligned and we fail fast.
    """
    for keys, vals in self.prepare_keys_values(stream):
        n_keys, n_vals = len(keys), len(vals)
        if n_keys != n_vals:
            # Mismatched batch lengths mean the two sides of the zip
            # drifted apart — surface it instead of silently truncating.
            raise ValueError("Can not deserialize RDD with different number of items"
                             " in pair: (%d, %d)" % (n_keys, n_vals))
        for pair in izip(keys, vals):
            yield pair
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ else:
|
|||
|
||||
from pyspark.context import SparkContext
|
||||
from pyspark.files import SparkFiles
|
||||
from pyspark.serializers import read_int
|
||||
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
|
||||
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
|
||||
|
||||
_have_scipy = False
|
||||
|
@ -339,6 +339,31 @@ class TestRDDFunctions(PySparkTestCase):
|
|||
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
|
||||
self.assertEquals(N, m)
|
||||
|
||||
def test_zip_with_different_serializers(self):
    """zip() should work even when the two RDDs use different serializers."""
    expected = [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]
    left = self.sc.parallelize(range(5))
    right = self.sc.parallelize(range(100, 105))
    self.assertEqual(left.zip(right).collect(), expected)
    # Re-serialize the two sides with mismatched serializers / batch
    # sizes and verify zip() still pairs the elements correctly.
    left = left._reserialize(BatchedSerializer(PickleSerializer(), 2))
    right = right._reserialize(MarshalSerializer())
    self.assertEqual(left.zip(right).collect(), expected)
|
||||
|
||||
def test_zip_with_different_number_of_items(self):
    """zip() must fail cleanly when the two RDDs do not line up item-for-item.

    Covers four failure modes: partition-count mismatch (detected eagerly),
    mismatched batched item counts in the JVM, mismatched pair lengths
    during deserialization, and equal totals with different per-partition
    distributions.
    """
    a = self.sc.parallelize(range(5), 2)
    # different number of partitions — rejected up front with ValueError
    b = self.sc.parallelize(range(100, 106), 3)
    self.assertRaises(ValueError, lambda: a.zip(b))
    # different number of batched items in JVM
    b = self.sc.parallelize(range(100, 104), 2)
    self.assertRaises(Exception, lambda: a.zip(b).count())
    # different number of items in one pair
    b = self.sc.parallelize(range(100, 106), 2)
    self.assertRaises(Exception, lambda: a.zip(b).count())
    # same total number of items, but different distributions
    a = self.sc.parallelize([2, 3], 2).flatMap(range)
    b = self.sc.parallelize([3, 2], 2).flatMap(range)
    # assertEqual instead of the deprecated assertEquals alias, for
    # consistency with the sibling tests in this file
    self.assertEqual(a.count(), b.count())
    self.assertRaises(Exception, lambda: a.zip(b).count())
|
||||
|
||||
|
||||
class TestIO(PySparkTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue