[SPARK-29341][PYTHON] Upgrade cloudpickle to 1.0.0

### What changes were proposed in this pull request?

This patch upgrades cloudpickle to version 1.0.0. Main changes:

1. Clean up unused functions: 936f16fac8
2. Fix relative imports inside function bodies: 31ecdd6f57
3. Write keyword-only arguments to the pickle: 6cb4718528

### Why are the changes needed?

We should include new bug fixes such as 6cb4718528, because users might use such Python functions in PySpark:

```python
>>> def f(a, *, b=1):
...     return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[Stage 0:>                                                        (0 + 12) / 12]19/10/03 00:42:24 ERROR Executor: Exception in task 3.0 in stage 0.0 (TID 3)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 598, in main
    process()
  File "/spark/python/lib/pyspark.zip/pyspark/worker.py", line 590, in process
    serializer.dump_stream(out_iter, outfile)
  File "/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 513, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/spark/python/lib/pyspark.zip/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
TypeError: f() missing 1 required keyword-only argument: 'b'
```

After:

```python
>>> def f(a, *, b=1):
...     return a + b
...
>>> rdd = sc.parallelize([1, 2, 3])
>>> rdd.map(f).collect()
[2, 3, 4]
```

### Does this PR introduce any user-facing change?

Yes. This fixes two bugs when pickling Python functions.

### How was this patch tested?

Existing tests.

Closes #26009 from viirya/upgrade-cloudpickle.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
commit 2bc3fff13b
parent 858bf76e35
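Change 2 above (relative imports inside function bodies) has no example in the commit message; here is a minimal sketch of the failure mode it fixes. The package layout (`mypkg`, `sibling`, `helper`) is hypothetical, for illustration only:

```python
# Hypothetical layout:
#   mypkg/__init__.py  -- defines func() below
#   mypkg/sibling.py   -- defines helper()

def func(x):
    from . import sibling   # relative import executed when func runs
    return sibling.helper(x)

# cloudpickle serializes func by value, so the executor rebuilds it with fresh
# globals. Unless __package__/__name__ (and friends) are copied into those
# globals -- the fix pulled in by this upgrade -- the relative import cannot
# be resolved remotely and raises ImportError.
```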
python/pyspark/broadcast.py:

```diff
@@ -21,10 +21,9 @@ import sys
 from tempfile import NamedTemporaryFile
 import threading
 
-from pyspark.cloudpickle import print_exec
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import ChunkedStream, pickle_protocol
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 
 if sys.version < '3':
     import cPickle as pickle
```
python/pyspark/cloudpickle.py:

```diff
@@ -591,6 +591,8 @@ class CloudPickler(Pickler):
             state['annotations'] = func.__annotations__
         if hasattr(func, '__qualname__'):
             state['qualname'] = func.__qualname__
+        if hasattr(func, '__kwdefaults__'):
+            state['kwdefaults'] = func.__kwdefaults__
         save(state)
         write(pickle.TUPLE)
         write(pickle.REDUCE)  # applies _fill_function on the tuple
```
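Background for the hunk above (plain Python semantics, not part of the patch): keyword-only defaults are stored on `__kwdefaults__`, not `__defaults__`, so a pickler that captures only the latter silently drops them:

```python
def f(a, *, b=1, c=2):
    return a + b + c

print(f.__defaults__)    # None -- no positional-parameter defaults here
print(f.__kwdefaults__)  # {'b': 1, 'c': 2} -- keyword-only defaults live here
```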
```diff
@@ -666,6 +668,15 @@ class CloudPickler(Pickler):
         # multiple invokations are bound to the same Cloudpickler.
         base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
 
+        if base_globals == {}:
+            # Add module attributes used to resolve relative imports
+            # instructions inside func.
+            for k in ["__package__", "__name__", "__path__", "__file__"]:
+                # Some built-in functions/methods such as object.__new__ have
+                # their __globals__ set to None in PyPy
+                if func.__globals__ is not None and k in func.__globals__:
+                    base_globals[k] = func.__globals__[k]
+
         return (code, f_globals, defaults, closure, dct, base_globals)
 
     def save_builtin_function(self, obj):
```
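As background for the hunk above (an illustration, not code from the patch): the import machinery resolves a relative import against the importing module's `__package__` (falling back to `__name__`), which is exactly the context the rebuilt function's globals were missing:

```python
import importlib.util

# Turning a relative module name into an absolute one requires a package
# context; without one, resolution fails with an error.
print(importlib.util.resolve_name(".serializers", "pyspark"))  # 'pyspark.serializers'
```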
```diff
@@ -979,43 +990,6 @@ def _restore_attr(obj, attr):
     return obj
 
 
-def _get_module_builtins():
-    return pickle.__builtins__
-
-
-def print_exec(stream):
-    ei = sys.exc_info()
-    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
-
-
-def _modules_to_main(modList):
-    """Force every module in modList to be placed into main"""
-    if not modList:
-        return
-
-    main = sys.modules['__main__']
-    for modname in modList:
-        if type(modname) is str:
-            try:
-                mod = __import__(modname)
-            except Exception:
-                sys.stderr.write('warning: could not import %s\n. '
-                                 'Your function may unexpectedly error due to this import failing;'
-                                 'A version mismatch is likely. Specific error was:\n' % modname)
-                print_exec(sys.stderr)
-            else:
-                setattr(main, mod.__name__, mod)
-
-
-# object generators:
-def _genpartial(func, args, kwds):
-    if not args:
-        args = ()
-    if not kwds:
-        kwds = {}
-    return partial(func, *args, **kwds)
-
-
 def _gen_ellipsis():
     return Ellipsis
 
```
```diff
@@ -1103,6 +1077,8 @@ def _fill_function(*args):
         func.__module__ = state['module']
     if 'qualname' in state:
         func.__qualname__ = state['qualname']
+    if 'kwdefaults' in state:
+        func.__kwdefaults__ = state['kwdefaults']
 
     cells = func.__closure__
     if cells is not None:
```
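The save-side hunk at line 591 and this restore-side hunk can be exercised together without a cluster; a minimal round-trip sketch, assuming a PySpark build that includes this patch:

```python
from pyspark.serializers import CloudPickleSerializer

ser = CloudPickleSerializer()

def f(a, *, b=1):
    return a + b

# Serialize with cloudpickle, then rebuild the function via _fill_function.
g = ser.loads(ser.dumps(f))
assert g.__kwdefaults__ == {'b': 1}
# Before the fix this call raised:
# TypeError: f() missing 1 required keyword-only argument: 'b'
assert g(1) == 2
```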
```diff
@@ -1188,15 +1164,6 @@ def _is_dynamic(module):
     return False
 
 
-"""Constructors for 3rd party libraries
-Note: These can never be renamed due to client compatibility issues"""
-
-
-def _getobject(modname, attribute):
-    mod = __import__(modname, fromlist=[attribute])
-    return mod.__dict__[attribute]
-
-
 """ Use copy_reg to extend global pickle definitions """
 
 if sys.version_info < (3, 4):  # pragma: no branch
```
python/pyspark/serializers.py:

```diff
@@ -69,7 +69,7 @@ else:
     pickle_protocol = pickle.HIGHEST_PROTOCOL
 
 from pyspark import cloudpickle
-from pyspark.util import _exception_message
+from pyspark.util import _exception_message, print_exec
 
 
 __all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
```
```diff
@@ -716,7 +716,7 @@ class CloudPickleSerializer(PickleSerializer):
                 msg = "Object too large to serialize: %s" % emsg
             else:
                 msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg)
-            cloudpickle.print_exec(sys.stderr)
+            print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
 
 
```
python/pyspark/util.py:

```diff
@@ -18,6 +18,7 @@
 
 import re
 import sys
+import traceback
 import inspect
 from py4j.protocol import Py4JJavaError
```
```diff
@@ -62,6 +63,11 @@ def _get_argspec(f):
     return argspec
 
 
+def print_exec(stream):
+    ei = sys.exc_info()
+    traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+
 class VersionUtils(object):
     """
     Provides utility method to determine Spark versions with given input string.
```
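For reference, a minimal standalone use of the relocated helper (illustrative only; `print_exec` is an internal utility):

```python
import sys
from pyspark.util import print_exec

try:
    1 / 0
except ZeroDivisionError:
    # Prints the traceback of the exception currently being handled to the
    # given stream -- what CloudPickleSerializer now does on pickling errors.
    print_exec(sys.stderr)
```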