[SPARK-6216] [PYSPARK] check python version of worker with driver
This PR reverts #5404 and instead passes the driver's Python version into the JVM; the worker checks it before deserializing the closure, so the check also works when the driver and worker run different major versions of Python.
Author: Davies Liu <davies@databricks.com>
Closes #6203 from davies/py_version and squashes the following commits:
b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver
(cherry picked from commit 32fbd297dd)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
parent 39623481fc
commit a8332098ce
@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {

@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
         val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(split.index)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Python includes (*.zip and *.egg files)

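The handshake only works because both sides agree on how strings are framed on the control stream. Below is a minimal, illustrative sketch of that framing, assuming (based on the writeUTF/utf8_deserializer pairing) a 4-byte big-endian length followed by UTF-8 bytes; it is not the actual Spark implementation.

    import io
    import struct
    import sys

    def write_utf(s, stream):
        """Driver side: frame a string as 4-byte big-endian length + UTF-8 bytes."""
        data = s.encode("utf-8")
        stream.write(struct.pack("!i", len(data)))
        stream.write(data)

    def read_utf(stream):
        """Worker side: read one length-prefixed UTF-8 string back."""
        (length,) = struct.unpack("!i", stream.read(4))
        return stream.read(length).decode("utf-8")

    buf = io.BytesIO()
    write_utf("%d.%d" % sys.version_info[:2], buf)   # what the driver sends
    buf.seek(0)
    assert read_utf(buf) == "%d.%d" % sys.version_info[:2]  # what the worker compares
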
@@ -173,6 +173,7 @@ class SparkContext(object):
             self._jvm.PythonAccumulatorParam(host, port))

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+        self.pythonVer = "%d.%d" % sys.version_info[:2]

         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have

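For illustration (not part of the patch), the value stored in pythonVer is simply the interpreter's major.minor rendered as a string, which makes the worker-side comparison a plain string equality test:

    import sys

    python_ver = "%d.%d" % sys.version_info[:2]  # e.g. "2.7" on CPython 2.7.x, "3.4" on 3.4.x
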
@@ -2260,7 +2260,7 @@ class RDD(object):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps((command, sys.version_info[:2]))
+    pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)

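Note that the version tuple no longer rides inside the cloudpickled command: serializing it with the closure meant a mismatch could only be detected after unpickling, which is exactly the step that breaks across Python versions. Sending the version as a plain UTF-8 string lets the worker reject the task before touching the pickled payload.
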
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec,
+                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                              bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()

@@ -157,6 +157,7 @@ class SQLContext(object):
             env,
             includes,
             self._sc.pythonExec,
+            self._sc.pythonVer,
             bvars,
             self._sc._javaAccumulator,
             returnType.json())

@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf

@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
     def test_with_different_versions_of_python(self):
         rdd = self.sc.parallelize(range(10))
         rdd.count()
-        version = sys.version_info
-        sys.version_info = (2, 0, 0)
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
         try:
             with QuietTest(self.sc):
                 self.assertRaises(Py4JJavaError, lambda: rdd.count())
         finally:
-            sys.version_info = version
+            self.sc.pythonVer = version


 class SparkSubmitTests(unittest.TestCase):

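The updated test no longer monkey-patches sys.version_info; it lies to the JVM by overriding sc.pythonVer instead, so the worker-side check trips and surfaces in the driver as a Py4JJavaError.
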
@@ -57,6 +57,12 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             exit(-1)

+        version = utf8_deserializer.loads(infile)
+        if version != "%d.%d" % sys.version_info[:2]:
+            raise Exception(("Python in worker has different version %s than that in " +
+                             "driver %s, PySpark cannot run with different minor versions") %
+                            ("%d.%d" % sys.version_info[:2], version))
+
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0

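As a worked example of the message this raises, with hypothetical values for a 2.7 worker paired with a driver that advertised "3.4":

    # Hypothetical values for illustration only.
    worker_ver = "2.7"   # stand-in for "%d.%d" % sys.version_info[:2] on the worker
    driver_ver = "3.4"   # value read from the control stream
    msg = ("Python in worker has different version %s than that in " +
           "driver %s, PySpark cannot run with different minor versions") % (worker_ver, driver_ver)
    # -> "Python in worker has different version 2.7 than that in driver 3.4,
    #     PySpark cannot run with different minor versions"
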
@@ -92,11 +98,7 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, profiler, deserializer, serializer), version = command
-        if version != sys.version_info[:2]:
-            raise Exception(("Python in worker has different version %s than that in " +
-                             "driver %s, PySpark cannot run with different minor versions") %
-                            (sys.version_info[:2], version))
+        func, profiler, deserializer, serializer = command
         init_time = time.time()

         def process():

@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      pythonVer: String,
       broadcastVars: JList[Broadcast[PythonBroadcast]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {

@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       envVars,
       pythonIncludes,
       pythonExec,
+      pythonVer,
       broadcastVars,
       accumulator,
       dataType,

@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {

   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
-      accumulator, dataType, exprs.map(_.expr))
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
+      broadcastVars, accumulator, dataType, exprs.map(_.expr))
     Column(udf)
   }
 }

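From here the change is mechanical: pythonVer is threaded through each constructor between the SQL layer and PythonRDD (UserDefinedPythonFunction above, PythonUDF and BatchPythonEvaluation below), so Python UDFs get the same fail-fast check as regular RDD closures.
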
@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,

@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
+      udf.pythonVer,
       udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>