[SPARK-6216] [PYSPARK] check python version of worker with driver
This PR reverts #5404 and instead passes the driver's Python version into the JVM, so that the worker can check it before deserializing the closure; as a result, the check also works when the driver and the worker run different major versions of Python. Author: Davies Liu <davies@databricks.com> Closes #6203 from davies/py_version and squashes the following commits: b8fb76e [Davies Liu] fix test 6ce5096 [Davies Liu] use string for version 47c6278 [Davies Liu] check python version of worker with driver
This commit is contained in:
parent
9dadf019b9
commit
32fbd297dd
|
@ -47,6 +47,7 @@ private[spark] class PythonRDD(
|
||||||
pythonIncludes: JList[String],
|
pythonIncludes: JList[String],
|
||||||
preservePartitoning: Boolean,
|
preservePartitoning: Boolean,
|
||||||
pythonExec: String,
|
pythonExec: String,
|
||||||
|
pythonVer: String,
|
||||||
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
||||||
accumulator: Accumulator[JList[Array[Byte]]])
|
accumulator: Accumulator[JList[Array[Byte]]])
|
||||||
extends RDD[Array[Byte]](parent) {
|
extends RDD[Array[Byte]](parent) {
|
||||||
|
@ -210,6 +211,8 @@ private[spark] class PythonRDD(
|
||||||
val dataOut = new DataOutputStream(stream)
|
val dataOut = new DataOutputStream(stream)
|
||||||
// Partition index
|
// Partition index
|
||||||
dataOut.writeInt(split.index)
|
dataOut.writeInt(split.index)
|
||||||
|
// Python version of driver
|
||||||
|
PythonRDD.writeUTF(pythonVer, dataOut)
|
||||||
// sparkFilesDir
|
// sparkFilesDir
|
||||||
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
|
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
|
||||||
// Python includes (*.zip and *.egg files)
|
// Python includes (*.zip and *.egg files)
|
||||||
|
|
|
@ -173,6 +173,7 @@ class SparkContext(object):
|
||||||
self._jvm.PythonAccumulatorParam(host, port))
|
self._jvm.PythonAccumulatorParam(host, port))
|
||||||
|
|
||||||
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
|
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
|
||||||
|
self.pythonVer = "%d.%d" % sys.version_info[:2]
|
||||||
|
|
||||||
# Broadcast's __reduce__ method stores Broadcast instances here.
|
# Broadcast's __reduce__ method stores Broadcast instances here.
|
||||||
# This allows other code to determine which Broadcast instances have
|
# This allows other code to determine which Broadcast instances have
|
||||||
|
|
|
@ -2260,7 +2260,7 @@ class RDD(object):
|
||||||
def _prepare_for_python_RDD(sc, command, obj=None):
|
def _prepare_for_python_RDD(sc, command, obj=None):
|
||||||
# the serialized command will be compressed by broadcast
|
# the serialized command will be compressed by broadcast
|
||||||
ser = CloudPickleSerializer()
|
ser = CloudPickleSerializer()
|
||||||
pickled_command = ser.dumps((command, sys.version_info[:2]))
|
pickled_command = ser.dumps(command)
|
||||||
if len(pickled_command) > (1 << 20): # 1M
|
if len(pickled_command) > (1 << 20): # 1M
|
||||||
# The broadcast will have same life cycle as created PythonRDD
|
# The broadcast will have same life cycle as created PythonRDD
|
||||||
broadcast = sc.broadcast(pickled_command)
|
broadcast = sc.broadcast(pickled_command)
|
||||||
|
@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
|
||||||
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
|
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
|
||||||
bytearray(pickled_cmd),
|
bytearray(pickled_cmd),
|
||||||
env, includes, self.preservesPartitioning,
|
env, includes, self.preservesPartitioning,
|
||||||
self.ctx.pythonExec,
|
self.ctx.pythonExec, self.ctx.pythonVer,
|
||||||
bvars, self.ctx._javaAccumulator)
|
bvars, self.ctx._javaAccumulator)
|
||||||
self._jrdd_val = python_rdd.asJavaRDD()
|
self._jrdd_val = python_rdd.asJavaRDD()
|
||||||
|
|
||||||
|
|
|
@ -157,6 +157,7 @@ class SQLContext(object):
|
||||||
env,
|
env,
|
||||||
includes,
|
includes,
|
||||||
self._sc.pythonExec,
|
self._sc.pythonExec,
|
||||||
|
self._sc.pythonVer,
|
||||||
bvars,
|
bvars,
|
||||||
self._sc._javaAccumulator,
|
self._sc._javaAccumulator,
|
||||||
returnType.json())
|
returnType.json())
|
||||||
|
|
|
@ -353,8 +353,8 @@ class UserDefinedFunction(object):
|
||||||
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
|
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
|
||||||
jdt = ssql_ctx.parseDataType(self.returnType.json())
|
jdt = ssql_ctx.parseDataType(self.returnType.json())
|
||||||
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
|
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
|
||||||
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
|
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
|
||||||
includes, sc.pythonExec, broadcast_vars,
|
sc.pythonExec, sc.pythonVer, broadcast_vars,
|
||||||
sc._javaAccumulator, jdt)
|
sc._javaAccumulator, jdt)
|
||||||
return judf
|
return judf
|
||||||
|
|
||||||
|
|
|
@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
|
||||||
def test_with_different_versions_of_python(self):
|
def test_with_different_versions_of_python(self):
|
||||||
rdd = self.sc.parallelize(range(10))
|
rdd = self.sc.parallelize(range(10))
|
||||||
rdd.count()
|
rdd.count()
|
||||||
version = sys.version_info
|
version = self.sc.pythonVer
|
||||||
sys.version_info = (2, 0, 0)
|
self.sc.pythonVer = "2.0"
|
||||||
try:
|
try:
|
||||||
with QuietTest(self.sc):
|
with QuietTest(self.sc):
|
||||||
self.assertRaises(Py4JJavaError, lambda: rdd.count())
|
self.assertRaises(Py4JJavaError, lambda: rdd.count())
|
||||||
finally:
|
finally:
|
||||||
sys.version_info = version
|
self.sc.pythonVer = version
|
||||||
|
|
||||||
|
|
||||||
class SparkSubmitTests(unittest.TestCase):
|
class SparkSubmitTests(unittest.TestCase):
|
||||||
|
|
|
@ -57,6 +57,12 @@ def main(infile, outfile):
|
||||||
if split_index == -1: # for unit tests
|
if split_index == -1: # for unit tests
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
|
version = utf8_deserializer.loads(infile)
|
||||||
|
if version != "%d.%d" % sys.version_info[:2]:
|
||||||
|
raise Exception(("Python in worker has different version %s than that in " +
|
||||||
|
"driver %s, PySpark cannot run with different minor versions") %
|
||||||
|
("%d.%d" % sys.version_info[:2], version))
|
||||||
|
|
||||||
# initialize global state
|
# initialize global state
|
||||||
shuffle.MemoryBytesSpilled = 0
|
shuffle.MemoryBytesSpilled = 0
|
||||||
shuffle.DiskBytesSpilled = 0
|
shuffle.DiskBytesSpilled = 0
|
||||||
|
@ -92,11 +98,7 @@ def main(infile, outfile):
|
||||||
command = pickleSer._read_with_length(infile)
|
command = pickleSer._read_with_length(infile)
|
||||||
if isinstance(command, Broadcast):
|
if isinstance(command, Broadcast):
|
||||||
command = pickleSer.loads(command.value)
|
command = pickleSer.loads(command.value)
|
||||||
(func, profiler, deserializer, serializer), version = command
|
func, profiler, deserializer, serializer = command
|
||||||
if version != sys.version_info[:2]:
|
|
||||||
raise Exception(("Python in worker has different version %s than that in " +
|
|
||||||
"driver %s, PySpark cannot run with different minor versions") %
|
|
||||||
(sys.version_info[:2], version))
|
|
||||||
init_time = time.time()
|
init_time = time.time()
|
||||||
|
|
||||||
def process():
|
def process():
|
||||||
|
|
|
@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
|
||||||
envVars: JMap[String, String],
|
envVars: JMap[String, String],
|
||||||
pythonIncludes: JList[String],
|
pythonIncludes: JList[String],
|
||||||
pythonExec: String,
|
pythonExec: String,
|
||||||
|
pythonVer: String,
|
||||||
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
||||||
accumulator: Accumulator[JList[Array[Byte]]],
|
accumulator: Accumulator[JList[Array[Byte]]],
|
||||||
stringDataType: String): Unit = {
|
stringDataType: String): Unit = {
|
||||||
|
@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
|
||||||
envVars,
|
envVars,
|
||||||
pythonIncludes,
|
pythonIncludes,
|
||||||
pythonExec,
|
pythonExec,
|
||||||
|
pythonVer,
|
||||||
broadcastVars,
|
broadcastVars,
|
||||||
accumulator,
|
accumulator,
|
||||||
dataType,
|
dataType,
|
||||||
|
|
|
@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
|
||||||
envVars: JMap[String, String],
|
envVars: JMap[String, String],
|
||||||
pythonIncludes: JList[String],
|
pythonIncludes: JList[String],
|
||||||
pythonExec: String,
|
pythonExec: String,
|
||||||
|
pythonVer: String,
|
||||||
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
||||||
accumulator: Accumulator[JList[Array[Byte]]],
|
accumulator: Accumulator[JList[Array[Byte]]],
|
||||||
dataType: DataType) {
|
dataType: DataType) {
|
||||||
|
|
||||||
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
|
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
|
||||||
def apply(exprs: Column*): Column = {
|
def apply(exprs: Column*): Column = {
|
||||||
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
|
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
|
||||||
accumulator, dataType, exprs.map(_.expr))
|
broadcastVars, accumulator, dataType, exprs.map(_.expr))
|
||||||
Column(udf)
|
Column(udf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
|
||||||
envVars: JMap[String, String],
|
envVars: JMap[String, String],
|
||||||
pythonIncludes: JList[String],
|
pythonIncludes: JList[String],
|
||||||
pythonExec: String,
|
pythonExec: String,
|
||||||
|
pythonVer: String,
|
||||||
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
broadcastVars: JList[Broadcast[PythonBroadcast]],
|
||||||
accumulator: Accumulator[JList[Array[Byte]]],
|
accumulator: Accumulator[JList[Array[Byte]]],
|
||||||
dataType: DataType,
|
dataType: DataType,
|
||||||
|
@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
|
||||||
udf.pythonIncludes,
|
udf.pythonIncludes,
|
||||||
false,
|
false,
|
||||||
udf.pythonExec,
|
udf.pythonExec,
|
||||||
|
udf.pythonVer,
|
||||||
udf.broadcastVars,
|
udf.broadcastVars,
|
||||||
udf.accumulator
|
udf.accumulator
|
||||||
).mapPartitions { iter =>
|
).mapPartitions { iter =>
|
||||||
|
|
Loading…
Reference in a new issue