[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code

## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it is cumbersome to pass so many parameters around in many places. This PR abstracts the Python function together with its context into a single `PythonFunction` wrapper, which simplifies some PySpark code and makes the logic clearer.
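
Conceptually, the change looks like the minimal sketch below. This is plain Scala with simplified types and a toy `compute`, not the actual Spark classes: the per-function context is bundled into one case class, and the runner unpacks only the fields it needs.

```scala
// Simplified sketch of the pattern this PR applies (real types live in
// org.apache.spark.api.python; fields and behavior are abbreviated here).

case class PythonFunction(
    command: Array[Byte],          // pickled Python closure
    envVars: Map[String, String],
    pythonIncludes: Seq[String],
    pythonExec: String,
    pythonVer: String)

class PythonRunner(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean) {
  // Pull out only the fields this class actually uses.
  private val pythonExec = func.pythonExec
  private val envVars = func.envVars

  def compute(partitionIndex: Int): Unit = {
    // Toy stand-in for launching a Python worker.
    println(s"Running ${func.command.length}-byte command with $pythonExec " +
      s"(partition $partitionIndex, buffer $bufferSize, reuse=$reuseWorker, env=$envVars)")
  }
}

object Example extends App {
  val func = PythonFunction(Array[Byte](1, 2, 3), Map("PYTHONHASHSEED" -> "0"),
    Seq.empty, "python3", "3.5")
  // Callers now pass one wrapper instead of threading every parameter through.
  new PythonRunner(func, bufferSize = 65536, reuseWorker = true).compute(partitionIndex = 0)
}
```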

## How was this patch tested?

by existing unit tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11342 from cloud-fan/python-clean.
Authored by Wenchen Fan on 2016-02-24 12:44:54 -08:00; committed by Davies Liu
parent f92f53faee
commit a60f91284c
8 changed files with 51 additions and 63 deletions

@@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 private[spark] class PythonRDD(
     parent: RDD[_],
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    preservePartitoning: Boolean,
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    func: PythonFunction,
+    preservePartitoning: Boolean)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -64,29 +58,37 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(
-      command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
-      bufferSize, reuse_worker)
+    val runner = new PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for a Python function, contains all necessary context to run the function in Python
+ * runner.
  */
-private[spark] class PythonRunner(
+private[spark] case class PythonFunction(
     command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+private[spark] class PythonRunner(
+    func: PythonFunction,
     bufferSize: Int,
     reuse_worker: Boolean)
   extends Logging {
 
+  private val envVars = func.envVars
+  private val pythonExec = func.pythonExec
+  private val accumulator = func.accumulator
+
   def compute(
       inputIterator: Iterator[_],
       partitionIndex: Int,
@@ -225,6 +227,11 @@ private[spark] class PythonRunner(
     @volatile private var _exception: Exception = null
 
+    private val pythonVer = func.pythonVer
+    private val pythonIncludes = func.pythonIncludes
+    private val broadcastVars = func.broadcastVars
+    private val command = func.command
+
     setDaemon(true)
 
     /** Contains the exception thrown while writing the parent iterator to the Python process. */

@@ -2309,7 +2309,7 @@ class RDD(object):
                 yield row
 
 
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     return pickled_command, broadcast_vars, env, includes
 
 
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+    assert deserializer, "deserializer should not be empty"
+    assert serializer, "serializer should not be empty"
+    command = (func, profiler, deserializer, serializer)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class PipelinedRDD(RDD):
     """
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
         else:
             profiler = None
 
-        command = (self.func, profiler, self._prev_jrdd_deserializer,
-                   self._jrdd_deserializer)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
-        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_cmd),
-                                             env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec, self.ctx.pythonVer,
-                                             bvars, self.ctx._javaAccumulator)
+        wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+                                      self._jrdd_deserializer, profiler)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+                                             self.preservesPartitioning)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:

@@ -29,7 +29,7 @@ else:
 from py4j.protocol import Py4JError
 
 from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter

@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
         sc = SparkContext.getOrCreate()
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        wrapped_func = _wrap_function(sc, func, ser, ser)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
-            broadcast_vars, sc._javaAccumulator, jdt)
+            name, wrapped_func, jdt)
         return judf
 
     def __del__(self):

@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       s"""
         | Registering new PythonUDF:
        | name: $name
-        | command: ${udf.command.toSeq}
-        | envVars: ${udf.envVars}
-        | pythonIncludes: ${udf.pythonIncludes}
-        | pythonExec: ${udf.pythonExec}
+        | command: ${udf.func.command.toSeq}
+        | envVars: ${udf.func.envVars}
+        | pythonIncludes: ${udf.func.pythonIncludes}
+        | pythonExec: ${udf.func.pythonExec}
         | dataType: ${udf.dataType}
       """.stripMargin)

@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
       // Output iterator for results from Python.
       val outputIterator = new PythonRunner(
-        udf.command,
-        udf.envVars,
-        udf.pythonIncludes,
-        udf.pythonExec,
-        udf.pythonVer,
-        udf.broadcastVars,
-        udf.accumulator,
+        udf.func,
         bufferSize,
         reuseWorker
       ).compute(inputIterator, context.partitionId(), context)

@@ -17,9 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
  */
 case class PythonUDF(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
   extends Expression with Unevaluable with NonSQLExpression with Logging {

@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
  */
 case class UserDefinedPythonFunction(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
-      accumulator, dataType, e)
+    PythonUDF(name, func, dataType, e)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */