SPARK-977 Added Python RDD.zip function
This was raised earlier as part of apache/incubator-spark#486. Author: Prabin Banka <prabin.banka@imaginea.com>. Closes #76 from prabinb/python-api-zip and squashes the following commits: b1a31a0 [Prabin Banka] Added Python RDD.zip function
This commit is contained in:
parent
5d98cfc1c8
commit
e1e09e0ef6
|
@ -30,7 +30,7 @@ from threading import Thread
|
|||
import warnings
|
||||
|
||||
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
|
||||
BatchedSerializer, CloudPickleSerializer, pack_long
|
||||
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
|
||||
from pyspark.join import python_join, python_left_outer_join, \
|
||||
python_right_outer_join, python_cogroup
|
||||
from pyspark.statcounter import StatCounter
|
||||
|
@ -1081,6 +1081,24 @@ class RDD(object):
|
|||
jrdd = self._jrdd.coalesce(numPartitions)
|
||||
return RDD(jrdd, self.ctx, self._jrdd_deserializer)
|
||||
|
||||
def zip(self, other):
    """
    Zips this RDD with another one, returning key-value pairs made of the
    first element of each RDD, then the second element of each RDD, and
    so on. Assumes that the two RDDs have the same number of partitions
    and the same number of elements in each partition (e.g. one was made
    through a map on the other).

    >>> x = sc.parallelize(range(0,5))
    >>> y = sc.parallelize(range(1000, 1005))
    >>> x.zip(y).collect()
    [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
    """
    # The zip itself is delegated to the underlying Java RDDs.
    zipped_jrdd = self._jrdd.zip(other._jrdd)
    # Each side keeps its own deserializer: this RDD's decodes the first
    # item of every pair, the other RDD's decodes the second.
    pair_deserializer = PairDeserializer(
        self._jrdd_deserializer, other._jrdd_deserializer)
    return RDD(zipped_jrdd, self.ctx, pair_deserializer)
|
||||
|
||||
|
||||
# TODO: `lookup` is disabled because we can't make direct comparisons based
|
||||
# on the key; we need to compare the hash of the key to the hash of the
|
||||
# keys in the pairs. This could be an expensive operation, since those
|
||||
|
|
|
@ -204,7 +204,7 @@ class CartesianDeserializer(FramedSerializer):
|
|||
self.key_ser = key_ser
|
||||
self.val_ser = val_ser
|
||||
|
||||
def load_stream(self, stream):
|
||||
def prepare_keys_values(self, stream):
|
||||
key_stream = self.key_ser._load_stream_without_unbatching(stream)
|
||||
val_stream = self.val_ser._load_stream_without_unbatching(stream)
|
||||
key_is_batched = isinstance(self.key_ser, BatchedSerializer)
|
||||
|
@ -212,6 +212,10 @@ class CartesianDeserializer(FramedSerializer):
|
|||
for (keys, vals) in izip(key_stream, val_stream):
|
||||
keys = keys if key_is_batched else [keys]
|
||||
vals = vals if val_is_batched else [vals]
|
||||
yield (keys, vals)
|
||||
|
||||
def load_stream(self, stream):
    """
    Yield every (key, value) combination from the stream.

    For each (keys, vals) pair produced by prepare_keys_values(), the
    full cross product of the two groups is emitted, one tuple at a time.
    """
    for key_group, val_group in self.prepare_keys_values(stream):
        for combo in product(key_group, val_group):
            yield combo
|
||||
|
||||
|
@ -224,6 +228,29 @@ class CartesianDeserializer(FramedSerializer):
|
|||
(str(self.key_ser), str(self.val_ser))
|
||||
|
||||
|
||||
class PairDeserializer(CartesianDeserializer):
    """
    Deserializes the JavaRDD zip() of two PythonRDDs.

    The stream is decoded through the inherited prepare_keys_values(),
    which yields one group of keys and one group of values at a time.
    Unlike CartesianDeserializer, which emits the cross product of each
    pair of groups, this class matches the groups up element by element.
    """

    def __init__(self, key_ser, val_ser):
        # key_ser decodes the first item of every pair, val_ser the second.
        self.key_ser = key_ser
        self.val_ser = val_ser

    def load_stream(self, stream):
        """
        Yield (key, value) tuples read from the stream.

        Raises ValueError if a group of keys and its matching group of
        values contain a different number of items.
        """
        for (keys, vals) in self.prepare_keys_values(stream):
            # izip() stops at the shorter input, so a size mismatch would
            # silently drop elements; check the sizes up front instead.
            # Groups without a length (e.g. iterators) are materialized
            # first so that they can be measured.
            keys = keys if hasattr(keys, '__len__') else list(keys)
            vals = vals if hasattr(vals, '__len__') else list(vals)
            if len(keys) != len(vals):
                raise ValueError(
                    "Can not deserialize RDD with different number of items"
                    " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair

    def __eq__(self, other):
        # Equal only to another PairDeserializer wrapping equal serializers.
        return isinstance(other, PairDeserializer) and \
            self.key_ser == other.key_ser and self.val_ser == other.val_ser

    def __str__(self):
        return "PairDeserializer<%s, %s>" % \
            (str(self.key_ser), str(self.val_ser))
|
||||
|
||||
|
||||
class NoOpSerializer(FramedSerializer):
|
||||
|
||||
def loads(self, obj): return obj
|
||||
|
|
Loading…
Reference in a new issue