[SPARK-1687] [PySpark] picklable namedtuple
Add a hook to replace the original namedtuple with a picklable one, so that namedtuple can be used in RDDs. PS: pyspark should be imported BEFORE "from collections import namedtuple". Author: Davies Liu <davies.liu@gmail.com> Closes #1623 from davies/namedtuple and squashes the following commits: 045dad8 [Davies Liu] remove unrelated code changes 4132f32 [Davies Liu] address comment 55b1c1a [Davies Liu] fix tests 61f86eb [Davies Liu] replace all the reference of namedtuple to new hacked one 98df6c6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple f7b1bde [Davies Liu] add hack for CloudPickleSerializer 0c5c849 [Davies Liu] Merge branch 'master' of github.com:apache/spark into namedtuple 21991e6 [Davies Liu] hack namedtuple in __main__ module, make it picklable. 93b03b8 [Davies Liu] pickable namedtuple
This commit is contained in:
parent
e053c55819
commit
59f84a9531
|
@ -65,6 +65,9 @@ from itertools import chain, izip, product
|
|||
import marshal
|
||||
import struct
|
||||
import sys
|
||||
import types
|
||||
import collections
|
||||
|
||||
from pyspark import cloudpickle
|
||||
|
||||
|
||||
|
@ -267,6 +270,63 @@ class NoOpSerializer(FramedSerializer):
|
|||
return obj
|
||||
|
||||
|
||||
# Hook namedtuple, make it picklable
|
||||
|
||||
__cls = {}
|
||||
|
||||
|
||||
def _restore(name, fields, value):
|
||||
""" Restore an object of namedtuple"""
|
||||
k = (name, fields)
|
||||
cls = __cls.get(k)
|
||||
if cls is None:
|
||||
cls = collections.namedtuple(name, fields)
|
||||
__cls[k] = cls
|
||||
return cls(*value)
|
||||
|
||||
|
||||
def _hack_namedtuple(cls):
|
||||
""" Make class generated by namedtuple picklable """
|
||||
name = cls.__name__
|
||||
fields = cls._fields
|
||||
def __reduce__(self):
|
||||
return (_restore, (name, fields, tuple(self)))
|
||||
cls.__reduce__ = __reduce__
|
||||
return cls
|
||||
|
||||
|
||||
def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # NOTE(review): this relies on CPython 2 function internals (func_code,
    # func_globals, func_closure, dict.iteritems); it will not run on
    # Python 3 as written.
    global _old_namedtuple  # or it will put in closure

    def _copy_func(f):
        # Clone the function object so we keep an untouched copy of the
        # original namedtuple() before its code object is patched below.
        return types.FunctionType(f.func_code, f.func_globals, f.func_name,
                                  f.func_defaults, f.func_closure)

    _old_namedtuple = _copy_func(collections.namedtuple)

    def namedtuple(name, fields, verbose=False, rename=False):
        # Same signature as the Python 2 collections.namedtuple: delegate to
        # the saved original, then patch the resulting class to be picklable.
        cls = _old_namedtuple(name, fields, verbose, rename)
        return _hack_namedtuple(cls)

    # replace namedtuple with new one: first inject the names the replacement
    # code object needs into namedtuple's own globals, then swap its code.
    collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.func_code = namedtuple.func_code

    # hack the cls already generated by namedtuple
    # those created in other module can be pickled as normal,
    # so only hack those in __main__ module
    for n, o in sys.modules["__main__"].__dict__.iteritems():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack inplace


# Run the hijack at import time, so "import pyspark" before
# "from collections import namedtuple" is sufficient.
_hijack_namedtuple()
|
||||
|
||||
|
||||
class PickleSerializer(FramedSerializer):
|
||||
"""
|
||||
Serializes objects using Python's cPickle serializer:
|
||||
|
|
|
@ -112,6 +112,17 @@ class TestMerger(unittest.TestCase):
|
|||
m._cleanup()
|
||||
|
||||
|
||||
class SerializationTestCase(unittest.TestCase):

    def test_namedtuple(self):
        # Round-trip a namedtuple through cPickle protocol 2; this works
        # because pyspark hijacks collections.namedtuple at import time.
        from collections import namedtuple
        from cPickle import dumps, loads

        P = namedtuple("P", "x y")
        before = P(1, 3)
        after = loads(dumps(before, 2))
        self.assertEquals(before, after)
|
||||
|
||||
|
||||
class PySparkTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -298,6 +309,14 @@ class TestRDDFunctions(PySparkTestCase):
|
|||
self.assertEqual([1], rdd.map(itemgetter(1)).collect())
|
||||
self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
|
||||
|
||||
def test_namedtuple_in_rdd(self):
    # Regression test for SPARK-1687: namedtuple instances must survive
    # serialization into and back out of an RDD.
    from collections import namedtuple

    Person = namedtuple("Person", "id firstName lastName")
    people = [Person(1, "Jon", "Doe"), Person(2, "Jane", "Doe")]
    rdd = self.sc.parallelize(people)
    self.assertEquals(people, rdd.collect())
|
||||
|
||||
|
||||
class TestIO(PySparkTestCase):
|
||||
|
||||
|
|
Loading…
Reference in a new issue