[SPARK-6216] [PYSPARK] check python version of worker with driver

This PR reverts #5404 and instead passes the driver's Python version into the JVM; the worker checks it before deserializing the closure, so the check also works when the driver and worker run different major versions of Python.
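For illustration, a standalone sketch of the check (the function names here are illustrative, not the actual Spark entry points): the driver records only `"%d.%d" % sys.version_info[:2]`, and the worker recomputes the same string and fails fast on a mismatch.

```python
import sys

def driver_python_ver():
    # Driver side: only major.minor is recorded, so e.g. 2.7.9 and 2.7.10 agree.
    return "%d.%d" % sys.version_info[:2]

def check_worker_python_ver(driver_ver):
    # Worker side: compare before deserializing the closure and fail fast.
    worker_ver = "%d.%d" % sys.version_info[:2]
    if worker_ver != driver_ver:
        raise Exception("Python in worker has different version %s than that in "
                        "driver %s, PySpark cannot run with different minor versions"
                        % (worker_ver, driver_ver))
```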

Author: Davies Liu <davies@databricks.com>

Closes #6203 from davies/py_version and squashes the following commits:

b8fb76e [Davies Liu] fix test
6ce5096 [Davies Liu] use string for version
47c6278 [Davies Liu] check python version of worker with driver
Authored by Davies Liu on 2015-05-18 12:55:13 -07:00; committed by Josh Rosen
parent 9dadf019b9
commit 32fbd297dd
10 changed files with 26 additions and 14 deletions


@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(split.index)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Python includes (*.zip and *.egg files)
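For reference, the version string travels to the worker in the same length-prefixed UTF-8 framing as the other strings written with `PythonRDD.writeUTF`. A minimal Python-side sketch of reading it back, assuming a 4-byte big-endian length prefix (which is what pyspark's `UTF8Deserializer` expects):

```python
import struct

def read_utf8(stream):
    # Counterpart of PythonRDD.writeUTF: read a 4-byte big-endian length,
    # then that many UTF-8 bytes.  This mirrors how the worker reads the
    # driver's Python version before anything else is deserialized.
    length = struct.unpack(">i", stream.read(4))[0]
    return stream.read(length).decode("utf-8")
```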


@@ -173,6 +173,7 @@ class SparkContext(object):
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
self.pythonVer = "%d.%d" % sys.version_info[:2]
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have


@@ -2260,7 +2260,7 @@ class RDD(object):
def _prepare_for_python_RDD(sc, command, obj=None):
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
pickled_command = ser.dumps((command, sys.version_info[:2]))
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
# The broadcast will have same life cycle as created PythonRDD
broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
self.ctx.pythonExec, self.ctx.pythonVer,
bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()


@@ -157,6 +157,7 @@ class SQLContext(object):
env,
includes,
self._sc.pythonExec,
self._sc.pythonVer,
bvars,
self._sc._javaAccumulator,
returnType.json())


@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
includes, sc.pythonExec, broadcast_vars,
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
sc.pythonExec, sc.pythonVer, broadcast_vars,
sc._javaAccumulator, jdt)
return judf


@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
def test_with_different_versions_of_python(self):
rdd = self.sc.parallelize(range(10))
rdd.count()
version = sys.version_info
sys.version_info = (2, 0, 0)
version = self.sc.pythonVer
self.sc.pythonVer = "2.0"
try:
with QuietTest(self.sc):
self.assertRaises(Py4JJavaError, lambda: rdd.count())
finally:
sys.version_info = version
self.sc.pythonVer = version
class SparkSubmitTests(unittest.TestCase):


@@ -57,6 +57,12 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
exit(-1)
version = utf8_deserializer.loads(infile)
if version != "%d.%d" % sys.version_info[:2]:
raise Exception(("Python in worker has different version %s than that in " +
"driver %s, PySpark cannot run with different minor versions") %
("%d.%d" % sys.version_info[:2], version))
# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
(func, profiler, deserializer, serializer), version = command
if version != sys.version_info[:2]:
raise Exception(("Python in worker has different version %s than that in " +
"driver %s, PySpark cannot run with different minor versions") %
(sys.version_info[:2], version))
func, profiler, deserializer, serializer = command
init_time = time.time()
def process():

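If the new check fires, the usual fix is to point the workers at the same interpreter as the driver, e.g. via `PYSPARK_PYTHON`. A hedged local-mode sketch; the interpreter path is only an example:

```python
import os

# Illustrative only: the worker is launched with pythonExec, which defaults
# to the PYSPARK_PYTHON environment variable, so pinning it before creating
# the SparkContext keeps driver and worker on the same interpreter.
os.environ["PYSPARK_PYTHON"] = "/usr/bin/python2.7"

from pyspark import SparkContext

sc = SparkContext("local", "python-version-check-demo")
print(sc.pythonVer)                        # e.g. "2.7"
print(sc.parallelize(range(10)).count())   # triggers the worker-side check
```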

@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
stringDataType: String): Unit = {
@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
envVars,
pythonIncludes,
pythonExec,
pythonVer,
broadcastVars,
accumulator,
dataType,


@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType) {
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
accumulator, dataType, exprs.map(_.expr))
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
broadcastVars, accumulator, dataType, exprs.map(_.expr))
Column(udf)
}
}


@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType,
@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
udf.pythonIncludes,
false,
udf.pythonExec,
udf.pythonVer,
udf.broadcastVars,
udf.accumulator
).mapPartitions { iter =>