[SPARK-29341][PYTHON] Upgrade cloudpickle to 1.0.0

### What changes were proposed in this pull request?

This patch upgrades cloudpickle to version 1.0.0.

Main changes:

1. Clean up unused functions: 936f16fac8
2. Fix relative imports inside a function body: 31ecdd6f57 (a short probe of this follows the list)
3. Write keyword-only arguments to the pickle: 6cb4718528
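
For item 2, the upgraded cloudpickle copies the defining module's `__package__`, `__name__`, `__path__`, and `__file__` into the globals it rebuilds for a function pickled by value, which is what the import machinery needs to resolve a relative import written inside the function body (see the second cloudpickle.py hunk below). A minimal probe of that behavior, run as a plain script with this PySpark on the path; the probe function is illustrative only, and the exact set of dunder keys depends on how the script is launched:

```python
import pickle

from pyspark import cloudpickle


def probe():
    # The body does not matter; we only inspect the globals the function
    # is rebuilt with after a pickle round trip.
    return None


restored = pickle.loads(cloudpickle.dumps(probe))
# With the upgraded cloudpickle this includes module attributes such as
# '__name__' and '__package__' (the ones relative imports rely on);
# before the upgrade the rebuilt globals held little beyond '__builtins__'.
print(sorted(k for k in restored.__globals__ if k.startswith("__")))
```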

### Why are the changes needed?

We should pick up new bug fixes such as 6cb4718528, because users might use Python functions like the following in PySpark.

Before:

```python
>>> def f(a, *, b=1):
...   return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[Stage 0:>                                                        (0 + 12) / 12]19/10/03 00:42:24 ERROR Executor: Exception in task 3.0 in stage 0.0 (TID 3)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 598, in main
    process()
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 590, in process
    serializer.dump_stream(out_iter, outfile)
  File "/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 513, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
TypeError: f() missing 1 required keyword-only argument: 'b'
```

After:

```python
>>> def f(a, *, b=1):
...   return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[2, 3, 4]
```
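
For context, CPython keeps the defaults of keyword-only parameters in a separate attribute, `__kwdefaults__`, rather than in `__defaults__`; 6cb4718528 makes cloudpickle write this attribute to the pickle and restore it (see the cloudpickle.py hunks below). A quick check in a plain Python shell:

```python
>>> def f(a, *, b=1):
...   return a + b
...
>>> f.__defaults__ is None    # no positional-or-keyword defaults here
True
>>> f.__kwdefaults__          # keyword-only defaults live in this attribute
{'b': 1}
```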

### Does this PR introduce any user-facing change?

Yes. It fixes two bugs in pickling Python functions: relative imports inside a function body and dropped keyword-only arguments.

### How was this patch tested?

Existing tests.

Closes #26009 from viirya/upgrade-cloudpickle.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Commit 2bc3fff13b (parent 858bf76e35), authored by Liang-Chi Hsieh on 2019-10-03 19:20:51 +09:00 and committed by HyukjinKwon. 4 changed files with 22 additions and 50 deletions.

python/pyspark/broadcast.py:

```diff
@@ -21,10 +21,9 @@ import sys
 from tempfile import NamedTemporaryFile
 import threading
-from pyspark.cloudpickle import print_exec
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import ChunkedStream, pickle_protocol
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 if sys.version < '3':
     import cPickle as pickle
```

python/pyspark/cloudpickle.py:

```diff
@@ -591,6 +591,8 @@ class CloudPickler(Pickler):
             state['annotations'] = func.__annotations__
         if hasattr(func, '__qualname__'):
             state['qualname'] = func.__qualname__
+        if hasattr(func, '__kwdefaults__'):
+            state['kwdefaults'] = func.__kwdefaults__
         save(state)
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
@@ -666,6 +668,15 @@ class CloudPickler(Pickler):
         # multiple invokations are bound to the same Cloudpickler.
         base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
+        if base_globals == {}:
+            # Add module attributes used to resolve relative imports
+            # instructions inside func.
+            for k in ["__package__", "__name__", "__path__", "__file__"]:
+                # Some built-in functions/methods such as object.__new__ have
+                # their __globals__ set to None in PyPy
+                if func.__globals__ is not None and k in func.__globals__:
+                    base_globals[k] = func.__globals__[k]
         return (code, f_globals, defaults, closure, dct, base_globals)
     def save_builtin_function(self, obj):
@@ -979,43 +990,6 @@ def _restore_attr(obj, attr):
     return obj
-def _get_module_builtins():
-    return pickle.__builtins__
-def print_exec(stream):
-    ei = sys.exc_info()
-    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
-def _modules_to_main(modList):
-    """Force every module in modList to be placed into main"""
-    if not modList:
-        return
-    main = sys.modules['__main__']
-    for modname in modList:
-        if type(modname) is str:
-            try:
-                mod = __import__(modname)
-            except Exception:
-                sys.stderr.write('warning: could not import %s\n. '
-                                 'Your function may unexpectedly error due to this import failing;'
-                                 'A version mismatch is likely. Specific error was:\n' % modname)
-                print_exec(sys.stderr)
-            else:
-                setattr(main, mod.__name__, mod)
-# object generators:
-def _genpartial(func, args, kwds):
-    if not args:
-        args = ()
-    if not kwds:
-        kwds = {}
-    return partial(func, *args, **kwds)
 def _gen_ellipsis():
     return Ellipsis
@@ -1103,6 +1077,8 @@ def _fill_function(*args):
         func.__module__ = state['module']
     if 'qualname' in state:
         func.__qualname__ = state['qualname']
+    if 'kwdefaults' in state:
+        func.__kwdefaults__ = state['kwdefaults']
     cells = func.__closure__
     if cells is not None:
@@ -1188,15 +1164,6 @@ def _is_dynamic(module):
         return False
-"""Constructors for 3rd party libraries
-Note: These can never be renamed due to client compatibility issues"""
-def _getobject(modname, attribute):
-    mod = __import__(modname, fromlist=[attribute])
-    return mod.__dict__[attribute]
 """ Use copy_reg to extend global pickle definitions """
 if sys.version_info < (3, 4):  # pragma: no branch
```
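
A quick round-trip check of the `save_function_tuple` and `_fill_function` changes above; this is a sketch of the expected behavior (runnable in a plain Python shell with this PySpark on the path), not a test added by the patch:

```python
import pickle

from pyspark import cloudpickle


def f(a, *, b=1):
    return a + b


# Serialize by value (f lives in __main__) and rebuild it from the bytes.
g = pickle.loads(cloudpickle.dumps(f))
print(g.__kwdefaults__)  # {'b': 1} with the upgraded cloudpickle; None before
print(g(1))              # 2; previously this raised the TypeError shown above
```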

python/pyspark/serializers.py:

```diff
@@ -69,7 +69,7 @@ else:
     pickle_protocol = pickle.HIGHEST_PROTOCOL
 from pyspark import cloudpickle
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
@@ -716,7 +716,7 @@ class CloudPickleSerializer(PickleSerializer):
                 msg = "Object too large to serialize: %s" % emsg
             else:
                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
-            cloudpickle.print_exec(sys.stderr)
+            print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
```
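
The behavioral piece here is the error path above; a small sketch of what it produces for an object cloudpickle cannot serialize (the lock is just a convenient unpicklable example):

```python
import threading

from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()
try:
    ser.dumps(threading.Lock())  # locks cannot be pickled
except Exception as e:
    # print_exec has already written the underlying traceback to stderr;
    # the re-raised error carries the descriptive message built above.
    print(type(e).__name__, e)   # PicklingError Could not serialize object: ...
```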

python/pyspark/util.py:

```diff
@@ -18,6 +18,7 @@
 import re
 import sys
+import traceback
 import inspect
 from py4j.protocol import Py4JJavaError
@@ -62,6 +63,11 @@ def _get_argspec(f):
     return argspec
+def print_exec(stream):
+    ei = sys.exc_info()
+    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
 class VersionUtils(object):
     """
     Provides utility method to determine Spark versions with given input string.
```
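
`print_exec` itself is tiny: it writes the traceback of the exception currently being handled to the given stream. A usage sketch against the relocated helper:

```python
import sys

from pyspark.util import print_exec

try:
    raise ValueError("boom")
except ValueError:
    # Equivalent to traceback.print_exc(file=sys.stderr) for the active exception.
    print_exec(sys.stderr)
```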