#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined function related classes and functions
"""
import functools
import sys

from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
    to_arrow_type, to_arrow_schema
from pyspark.util import _get_argspec

__all__ = ["UDFRegistration"]


def _wrap_function(sc, func, returnType):
    # Pickle the user function together with its return type, and wrap the result in a
    # JVM PythonFunction along with the captured environment, includes and broadcasts.
    command = (func, returnType)
    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)


def _create_udf(f, returnType, evalType):

    if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                    PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                    PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):

        from pyspark.sql.utils import require_minimum_pyarrow_version
        require_minimum_pyarrow_version()

        argspec = _get_argspec(f)

        if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
                argspec.varargs is None:
            raise ValueError(
                "Invalid function: 0-arg pandas_udfs are not supported. "
                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
            )

        if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
                and len(argspec.args) not in (1, 2):
            raise ValueError(
                "Invalid function: pandas_udfs with function type GROUPED_MAP "
                "must take either one argument (data) or two arguments (key, data).")

    # Set the name of the UserDefinedFunction object to be the name of function f
    udf_obj = UserDefinedFunction(
        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
    return udf_obj._wrapped()


class UserDefinedFunction(object):
    """
    User defined function in Python.

    .. versionadded:: 1.3
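
    .. note:: :class:`UserDefinedFunction` is not meant to be instantiated directly;
        use :meth:`pyspark.sql.functions.udf` instead.

    A minimal, illustrative sketch (assuming an active SparkSession ``spark``; the
    result line is skipped here):

    >>> from pyspark.sql.functions import udf
    >>> slen = udf(lambda s: len(s), "integer")
    >>> spark.sql("SELECT 'foo' AS text").select(slen("text")).collect()  # doctest: +SKIP
    [Row(<lambda>(text)=3)]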
    """
    def __init__(self, func,
                 returnType=StringType(),
                 name=None,
                 evalType=PythonEvalType.SQL_BATCHED_UDF,
                 deterministic=True):
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
                "{0}".format(type(func)))

        if not isinstance(returnType, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str "
                "but is {}".format(returnType))

        if not isinstance(evalType, int):
            raise TypeError(
                "Invalid evalType: evalType should be an int but is {}".format(evalType))

        self.func = func
        self._returnType = returnType
        self._returnType_placeholder = None
        # Stores the UserDefinedPythonFunction jobj, once initialized
        self._judf_placeholder = None
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self):
        # This makes sure this is called after SparkContext is initialized.
        # ``_parse_datatype_string`` accesses the JVM for parsing a DDL formatted string.
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, DataType):
                self._returnType_placeholder = self._returnType
            else:
                self._returnType_placeholder = _parse_datatype_string(self._returnType)

        if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
            try:
                to_arrow_type(self._returnType_placeholder)
            except TypeError:
                raise NotImplementedError(
                    "Invalid returnType with scalar Pandas UDFs: %s is "
                    "not supported" % str(self._returnType_placeholder))
        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            if isinstance(self._returnType_placeholder, StructType):
                try:
                    to_arrow_schema(self._returnType_placeholder)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid returnType with grouped map Pandas UDFs: "
                        "%s is not supported" % str(self._returnType_placeholder))
            else:
                raise TypeError("Invalid returnType for grouped map Pandas "
                                "UDFs: returnType must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
            try:
                to_arrow_type(self._returnType_placeholder)
            except TypeError:
                raise NotImplementedError(
                    "Invalid returnType with grouped aggregate Pandas UDFs: "
                    "%s is not supported" % str(self._returnType_placeholder))

        return self._returnType_placeholder

    @property
    def _judf(self):
        # It is possible that concurrent access to a newly created UDF
        # will initialize multiple UserDefinedPythonFunctions.
        # This is unlikely, doesn't affect correctness,
        # and should have a minimal performance impact.
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder

    def _create_judf(self):
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

    # This function is for improving the online help system in the interactive interpreter.
    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
    # argument annotation. (See: SPARK-19161)
    def _wrapped(self):
        """
        Wrap this udf with a function and attach the docstring from func.
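
        A minimal sketch of the effect, assuming :func:`pyspark.sql.functions.udf` is
        used as a decorator (``slen`` is just an example name; the check is skipped in
        automated doctests):

        >>> from pyspark.sql.functions import udf
        >>> @udf("integer")
        ... def slen(s):
        ...     '''Length of the given string.'''
        ...     return len(s)
        >>> slen.__doc__  # doctest: +SKIP
        'Length of the given string.'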
        """

        # It is possible for a callable instance without a __name__ attribute and/or a
        # __module__ attribute to be wrapped here. For example, functools.partial. In this
        # case, we should avoid copying the attributes of the wrapped function onto the
        # wrapper function. So, we take these attribute names out of the default names to
        # set and then manually assign them after wrapping.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args):
            return self(*args)

        wrapper.__name__ = self._name
        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                              else self.func.__class__.__module__)

        wrapper.func = self.func
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
        wrapper.deterministic = self.deterministic
        wrapper.asNondeterministic = functools.wraps(
            self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
        return wrapper

    def asNondeterministic(self):
        """
        Updates this UserDefinedFunction to be nondeterministic.
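
        A minimal usage sketch (assuming an active SparkSession ``spark``; the query
        is skipped here because its result is random):

        >>> import random
        >>> from pyspark.sql.functions import udf
        >>> from pyspark.sql.types import IntegerType
        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
        >>> spark.range(1).select(random_udf()).collect()  # doctest: +SKIP
        [Row(<lambda>()=26)]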

        .. versionadded:: 2.3
        """
        # Here, we explicitly clean the cache to create a JVM UDF instance
        # with 'deterministic' updated. See SPARK-23233.
        self._judf_placeholder = None
        self.deterministic = False
        return self


class UDFRegistration(object):
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.
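
    A minimal sketch of typical use (assuming an active SparkSession ``spark``;
    execution is skipped here):

    >>> _ = spark.udf.register("plus_one", lambda x: x + 1, "integer")  # doctest: +SKIP
    >>> spark.sql("SELECT plus_one(41)").collect()  # doctest: +SKIP
    [Row(plus_one(41)=42)]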

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession):
        self.sparkSession = sparkSession

    @ignore_unicode_prefix
    @since("1.3.1")
    def register(self, name, f, returnType=None):
        """Register a Python function (including lambda function) or a user-defined function
        as a SQL function.

        :param name: name of the user-defined function in SQL statements.
        :param f: a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        :param returnType: the return type of the registered user-defined function. The value can
            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        :return: a user-defined function.

        To register a nondeterministic Python function, users need to first build
        a nondeterministic user-defined function for the Python function and then register it
        as a SQL function.

        `returnType` can be optionally specified when `f` is a Python function but not
        when `f` is a user-defined function. Please see below.

        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)=u'4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)=u'3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function:

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
            ... def add_one(x):
            ...     return x + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
            ... def sum_udf(v):
            ...     return v.sum()
            ...
            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
            >>> spark.sql(q).collect()  # doctest: +SKIP
            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]

        .. note:: Registration for a user-defined function (case 2.) was added in
            Spark 2.3.0.
        """

        # This is to check whether the input function is from a user-defined function or
        # a plain Python function.
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                raise TypeError(
                    "Invalid returnType: data type can not be specified when f is "
                    "a user-defined function, but got %s." % returnType)
            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
                raise ValueError(
                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
                    "SQL_GROUPED_AGG_PANDAS_UDF")
            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """Register a Java user-defined function as a SQL function.

        In addition to a name and the function itself, the return type can be optionally
        specified. When the return type is not specified, it is inferred via reflection.

        :param name: name of the user-defined function
        :param javaClassName: fully qualified name of the Java class
        :param returnType: the return type of the registered Java function. The value can be
            either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        >>> spark.sql("SELECT javaStringLength('test')").collect()
        [Row(UDF:javaStringLength(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        >>> spark.sql("SELECT javaStringLength2('test')").collect()
        [Row(UDF:javaStringLength2(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
        >>> spark.sql("SELECT javaStringLength3('test')").collect()
        [Row(UDF:javaStringLength3(test)=4)]
        """

        jdt = None
        if returnType is not None:
            if not isinstance(returnType, DataType):
                returnType = _parse_datatype_string(returnType)
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaUDAF(self, name, javaClassName):
        """Register a Java user-defined aggregate function as a SQL function.

        :param name: name of the user-defined aggregate function
        :param javaClassName: fully qualified name of the Java class

        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
        >>> df.createOrReplaceTempView("df")
        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
        """

        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf
    globs = pyspark.sql.udf.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.udf tests")\
        .getOrCreate()
    globs['spark'] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.udf, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()
|