[SPARK-13467] [PYSPARK] abstract python function to simplify pyspark code
## What changes were proposed in this pull request?

When we pass a Python function to the JVM side, we also need to send its context, e.g. `envVars`, `pythonIncludes`, `pythonExec`, etc. However, it's annoying to pass so many parameters around in so many places. This PR abstracts the Python function together with its context into a single `PythonFunction` wrapper, to simplify some PySpark code and make the logic clearer.

## How was this patch tested?

By existing unit tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11342 from cloud-fan/python-clean.
parent f92f53faee
commit a60f91284c
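The change is a classic "parameter object" refactoring: the pickled command and its six pieces of context used to cross every layer as separate arguments, and now travel as one value. A minimal sketch of the shape of the change (illustrative Python with stand-in types, not Spark's actual classes):

```python
from collections import namedtuple

# Illustrative stand-in for the new JVM-side PythonFunction wrapper:
# the pickled command plus everything needed to run it in a worker.
PythonFunction = namedtuple("PythonFunction", [
    "command",          # pickled (func, profiler, deserializer, serializer)
    "env_vars",         # environment variables for the Python worker
    "python_includes",  # files shipped via sc.addPyFile
    "python_exec",      # which interpreter to launch
    "python_ver",       # expected Python version
    "broadcast_vars",   # broadcast variables the function closes over
    "accumulator",      # accumulator collecting updates from workers
])

# Before: every layer repeats the full parameter list.
def launch_worker_before(command, env_vars, python_includes, python_exec,
                         python_ver, broadcast_vars, accumulator,
                         buffer_size, reuse_worker):
    ...

# After: callers hand over one object; the runner unpacks what it needs.
def launch_worker_after(func, buffer_size, reuse_worker):
    ...
```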
`PythonRDD.scala`:

```diff
@@ -42,14 +42,8 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
 private[spark] class PythonRDD(
     parent: RDD[_],
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    preservePartitoning: Boolean,
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    func: PythonFunction,
+    preservePartitoning: Boolean)
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
@@ -64,29 +58,37 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(
-      command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator,
-      bufferSize, reuse_worker)
+    val runner = new PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
 
 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for a Python function, contains all necessary context to run the function in Python
+ * runner.
+ */
+private[spark] case class PythonFunction(
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    pythonExec: String,
+    pythonVer: String,
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[JList[Array[Byte]]])
+
+/**
+ * A helper class to run Python UDFs in Spark.
  */
 private[spark] class PythonRunner(
-    command: Array[Byte],
-    envVars: JMap[String, String],
-    pythonIncludes: JList[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]],
+    func: PythonFunction,
     bufferSize: Int,
     reuse_worker: Boolean)
   extends Logging {
 
+  private val envVars = func.envVars
+  private val pythonExec = func.pythonExec
+  private val accumulator = func.accumulator
+
   def compute(
       inputIterator: Iterator[_],
       partitionIndex: Int,
@@ -225,6 +227,11 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
+    private val pythonVer = func.pythonVer
+    private val pythonIncludes = func.pythonIncludes
+    private val broadcastVars = func.broadcastVars
+    private val command = func.command
+
     setDaemon(true)
 
     /** Contains the exception thrown while writing the parent iterator to the Python process. */
```
`pyspark/rdd.py`:

```diff
@@ -2309,7 +2309,7 @@ class RDD(object):
             yield row
 
 
-def _prepare_for_python_RDD(sc, command, obj=None):
+def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
     pickled_command = ser.dumps(command)
@@ -2329,6 +2329,15 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     return pickled_command, broadcast_vars, env, includes
 
 
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+    assert deserializer, "deserializer should not be empty"
+    assert serializer, "serializer should not be empty"
+    command = (func, profiler, deserializer, serializer)
+    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
 class PipelinedRDD(RDD):
 
     """
```
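For a sense of what the new helper replaces, here is a hedged sketch of calling `_wrap_function` directly (it is an internal API; this assumes a live `SparkContext` and the usual `(split, iterator)` function shape):

```python
# Sketch only: _wrap_function is internal; shown here to illustrate what
# it bundles. Assumes a running SparkContext.
from pyspark import SparkContext
from pyspark.rdd import _wrap_function
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer

sc = SparkContext.getOrCreate()
ser = AutoBatchedSerializer(PickleSerializer())

def double_values(split_index, iterator):
    return (x * 2 for x in iterator)

# One call now replaces the hand-built command tuple, the call to
# _prepare_for_python_RDD, and a seven-argument py4j constructor:
wrapped = _wrap_function(sc, double_values, ser, ser)
# `wrapped` is a JVM-side PythonFunction carrying command, envVars,
# pythonIncludes, pythonExec, pythonVer, broadcastVars and accumulator.
```

The `PipelinedRDD` call site then shrinks accordingly: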
```diff
@@ -2390,14 +2399,10 @@ class PipelinedRDD(RDD):
         else:
             profiler = None
 
-        command = (self.func, profiler, self._prev_jrdd_deserializer,
-                   self._jrdd_deserializer)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
-        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_cmd),
-                                             env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec, self.ctx.pythonVer,
-                                             bvars, self.ctx._javaAccumulator)
+        wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer,
+                                      self._jrdd_deserializer, profiler)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func,
+                                             self.preservesPartitioning)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:
```
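None of this plumbing is user-visible at the RDD level; materializing any pipelined transformation now goes through the `_wrap_function` path above. For example (standard PySpark, assuming a live `SparkContext` named `sc`):

```python
# Standard usage; collecting the RDD builds the JVM PythonRDD through
# the new wrapped_func path shown in the diff above.
rdd = sc.parallelize(range(4)).map(lambda x: x * 2)
print(rdd.collect())  # [0, 2, 4, 6]
```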
`pyspark/sql/context.py`:

```diff
@@ -29,7 +29,7 @@ else:
     from py4j.protocol import Py4JError
 
 from pyspark import since
-from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
     _infer_schema, _has_nulltype, _merge_type, _create_converter
```
`pyspark/sql/functions.py`:

```diff
@@ -25,7 +25,7 @@ if sys.version < "3":
     from itertools import imap as map
 
 from pyspark import since, SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
+from pyspark.rdd import _wrap_function, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1645,16 +1645,14 @@ class UserDefinedFunction(object):
         f, returnType = self.func, self.returnType  # put them in closure `func`
         func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
         sc = SparkContext.getOrCreate()
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        wrapped_func = _wrap_function(sc, func, ser, ser)
         ctx = SQLContext.getOrCreate(sc)
         jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
         if name is None:
             name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            name, bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer,
-            broadcast_vars, sc._javaAccumulator, jdt)
+            name, wrapped_func, jdt)
         return judf
 
     def __del__(self):
```
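The public UDF surface is likewise unchanged; a UDF defined the usual way now carries one wrapped function into `UserDefinedPythonFunction` instead of seven separate constructor arguments. For example (standard PySpark, assuming a `SQLContext` named `sqlContext`):

```python
# Standard public-API usage; the plumbing above runs underneath.
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

str_length = udf(lambda s: len(s), IntegerType())
df = sqlContext.createDataFrame([("spark",), ("pyspark",)], ["word"])
df.select(str_length(df.word).alias("len")).show()
```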
`UDFRegistration.scala`:

```diff
@@ -43,10 +43,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       s"""
         | Registering new PythonUDF:
         | name: $name
-        | command: ${udf.command.toSeq}
-        | envVars: ${udf.envVars}
-        | pythonIncludes: ${udf.pythonIncludes}
-        | pythonExec: ${udf.pythonExec}
+        | command: ${udf.func.command.toSeq}
+        | envVars: ${udf.func.envVars}
+        | pythonIncludes: ${udf.func.pythonIncludes}
+        | pythonExec: ${udf.func.pythonExec}
         | dataType: ${udf.dataType}
       """.stripMargin)
```
`BatchPythonEvaluation.scala`:

```diff
@@ -76,13 +76,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
 
     // Output iterator for results from Python.
     val outputIterator = new PythonRunner(
-      udf.command,
-      udf.envVars,
-      udf.pythonIncludes,
-      udf.pythonExec,
-      udf.pythonVer,
-      udf.broadcastVars,
-      udf.accumulator,
+      udf.func,
       bufferSize,
       reuseWorker
     ).compute(inputIterator, context.partitionId(), context)
```
`PythonUDF.scala`:

```diff
@@ -17,9 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
@@ -28,13 +27,7 @@ import org.apache.spark.sql.types.DataType
  */
 case class PythonUDF(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
   extends Expression with Unevaluable with NonSQLExpression with Logging {
```
`UserDefinedPythonFunction.scala`:

```diff
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.Accumulator
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.types.DataType
@@ -29,18 +27,11 @@ import org.apache.spark.sql.types.DataType
  */
 case class UserDefinedPythonFunction(
     name: String,
-    command: Array[Byte],
-    envVars: java.util.Map[String, String],
-    pythonIncludes: java.util.List[String],
-    pythonExec: String,
-    pythonVer: String,
-    broadcastVars: java.util.List[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[java.util.List[Array[Byte]]],
+    func: PythonFunction,
     dataType: DataType) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
-      accumulator, dataType, e)
+    PythonUDF(name, func, dataType, e)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
```