fe3e34dda6
## What changes were proposed in this pull request? This PR adds some tests converted from `pgSQL/case.sql'` to test UDFs. Please see contribution guide of this umbrella ticket - [SPARK-27921](https://issues.apache.org/jira/browse/SPARK-27921). This PR also contains two minor fixes: 1. Change name of Scala UDF from `UDF:name(...)` to `name(...)` to be consistent with Python' 2. Fix Scala UDF at `IntegratedUDFTestUtils.scala ` to handle `null` in strings. <details><summary>Diff comparing to 'pgSQL/case.sql'</summary> <p> ```diff diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out index fa078d16d6d..55bef64338f 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/case.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/pgSQL/udf-case.sql.out -115,7 +115,7 struct<> -- !query 13 SELECT '3' AS `One`, CASE - WHEN 1 < 2 THEN 3 + WHEN CAST(udf(1 < 2) AS boolean) THEN 3 END AS `Simple WHEN` -- !query 13 schema struct<One:string,Simple WHEN:int> -126,10 +126,10 struct<One:string,Simple WHEN:int> -- !query 14 SELECT '<NULL>' AS `One`, CASE - WHEN 1 > 2 THEN 3 + WHEN 1 > 2 THEN udf(3) END AS `Simple default` -- !query 14 schema -struct<One:string,Simple default:int> +struct<One:string,Simple default:string> -- !query 14 output <NULL> NULL -137,17 +137,17 struct<One:string,Simple default:int> -- !query 15 SELECT '3' AS `One`, CASE - WHEN 1 < 2 THEN 3 - ELSE 4 + WHEN udf(1) < 2 THEN udf(3) + ELSE udf(4) END AS `Simple ELSE` -- !query 15 schema -struct<One:string,Simple ELSE:int> +struct<One:string,Simple ELSE:string> -- !query 15 output 3 3 -- !query 16 -SELECT '4' AS `One`, +SELECT udf('4') AS `One`, CASE WHEN 1 > 2 THEN 3 ELSE 4 -159,10 +159,10 struct<One:string,ELSE default:int> -- !query 17 -SELECT '6' AS `One`, +SELECT udf('6') AS `One`, CASE - WHEN 1 > 2 THEN 3 - WHEN 4 < 5 THEN 6 + WHEN CAST(udf(1 > 2) AS boolean) THEN 3 + WHEN udf(4) < 5 THEN 6 
ELSE 7 END AS `Two WHEN with default` -- !query 17 schema -173,7 +173,7 struct<One:string,Two WHEN with default:int> -- !query 18 SELECT '7' AS `None`, - CASE WHEN rand() < 0 THEN 1 + CASE WHEN rand() < udf(0) THEN 1 END AS `NULL on no matches` -- !query 18 schema struct<None:string,NULL on no matches:int> -182,36 +182,36 struct<None:string,NULL on no matches:int> -- !query 19 -SELECT CASE WHEN 1=0 THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END +SELECT CASE WHEN CAST(udf(1=0) AS boolean) THEN 1/0 WHEN 1=1 THEN 1 ELSE 2/0 END -- !query 19 schema -struct<CASE WHEN (1 = 0) THEN (CAST(1 AS DOUBLE) / CAST(0 AS DOUBLE)) WHEN (1 = 1) THEN CAST(1 AS DOUBLE) ELSE (CAST(2 AS DOUBLE) / CAST(0 AS DOUBLE)) END:double> +struct<CASE WHEN CAST(udf((1 = 0)) AS BOOLEAN) THEN (CAST(1 AS DOUBLE) / CAST(0 AS DOUBLE)) WHEN (1 = 1) THEN CAST(1 AS DOUBLE) ELSE (CAST(2 AS DOUBLE) / CAST(0 AS DOUBLE)) END:double> -- !query 19 output 1.0 -- !query 20 -SELECT CASE 1 WHEN 0 THEN 1/0 WHEN 1 THEN 1 ELSE 2/0 END +SELECT CASE 1 WHEN 0 THEN 1/udf(0) WHEN 1 THEN 1 ELSE 2/0 END -- !query 20 schema -struct<CASE WHEN (1 = 0) THEN (CAST(1 AS DOUBLE) / CAST(0 AS DOUBLE)) WHEN (1 = 1) THEN CAST(1 AS DOUBLE) ELSE (CAST(2 AS DOUBLE) / CAST(0 AS DOUBLE)) END:double> +struct<CASE WHEN (1 = 0) THEN (CAST(1 AS DOUBLE) / CAST(CAST(udf(0) AS DOUBLE) AS DOUBLE)) WHEN (1 = 1) THEN CAST(1 AS DOUBLE) ELSE (CAST(2 AS DOUBLE) / CAST(0 AS DOUBLE)) END:double> -- !query 20 output 1.0 -- !query 21 -SELECT CASE WHEN i > 100 THEN 1/0 ELSE 0 END FROM case_tbl +SELECT CASE WHEN i > 100 THEN udf(1/0) ELSE udf(0) END FROM case_tbl -- !query 21 schema -struct<CASE WHEN (i > 100) THEN (CAST(1 AS DOUBLE) / CAST(0 AS DOUBLE)) ELSE CAST(0 AS DOUBLE) END:double> +struct<CASE WHEN (i > 100) THEN udf((cast(1 as double) / cast(0 as double))) ELSE udf(0) END:string> -- !query 21 output -0.0 -0.0 -0.0 -0.0 +0 +0 +0 +0 -- !query 22 -SELECT CASE 'a' WHEN 'a' THEN 1 ELSE 2 END +SELECT CASE 'a' WHEN 'a' THEN udf(1) ELSE udf(2) END -- !query 22 schema 
-struct<CASE WHEN (a = a) THEN 1 ELSE 2 END:int> +struct<CASE WHEN (a = a) THEN udf(1) ELSE udf(2) END:string> -- !query 22 output 1 -283,7 +283,7 big -- !query 27 -SELECT * FROM CASE_TBL WHERE COALESCE(f,i) = 4 +SELECT * FROM CASE_TBL WHERE udf(COALESCE(f,i)) = 4 -- !query 27 schema struct<i:int,f:double> -- !query 27 output -291,7 +291,7 struct<i:int,f:double> -- !query 28 -SELECT * FROM CASE_TBL WHERE NULLIF(f,i) = 2 +SELECT * FROM CASE_TBL WHERE udf(NULLIF(f,i)) = 2 -- !query 28 schema struct<i:int,f:double> -- !query 28 output -299,10 +299,10 struct<i:int,f:double> -- !query 29 -SELECT COALESCE(a.f, b.i, b.j) +SELECT udf(COALESCE(a.f, b.i, b.j)) FROM CASE_TBL a, CASE2_TBL b -- !query 29 schema -struct<coalesce(f, CAST(i AS DOUBLE), CAST(j AS DOUBLE)):double> +struct<udf(coalesce(f, cast(i as double), cast(j as double))):string> -- !query 29 output -30.3 -30.3 -332,8 +332,8 struct<coalesce(f, CAST(i AS DOUBLE), CAST(j AS DOUBLE)):double> -- !query 30 SELECT * - FROM CASE_TBL a, CASE2_TBL b - WHERE COALESCE(a.f, b.i, b.j) = 2 + FROM CASE_TBL a, CASE2_TBL b + WHERE udf(COALESCE(a.f, b.i, b.j)) = 2 -- !query 30 schema struct<i:int,f:double,i:int,j:int> -- !query 30 output -342,7 +342,7 struct<i:int,f:double,i:int,j:int> -- !query 31 -SELECT '' AS Five, NULLIF(a.i,b.i) AS `NULLIF(a.i,b.i)`, +SELECT udf('') AS Five, NULLIF(a.i,b.i) AS `NULLIF(a.i,b.i)`, NULLIF(b.i, 4) AS `NULLIF(b.i,4)` FROM CASE_TBL a, CASE2_TBL b -- !query 31 schema -377,7 +377,7 struct<Five:string,NULLIF(a.i,b.i):int,NULLIF(b.i,4):int> -- !query 32 SELECT '' AS `Two`, * FROM CASE_TBL a, CASE2_TBL b - WHERE COALESCE(f,b.i) = 2 + WHERE CAST(udf(COALESCE(f,b.i) = 2) AS boolean) -- !query 32 schema struct<Two:string,i:int,f:double,i:int,j:int> -- !query 32 output -388,15 +388,15 struct<Two:string,i:int,f:double,i:int,j:int> -- !query 33 SELECT CASE (CASE vol('bar') - WHEN 'foo' THEN 'it was foo!' - WHEN vol(null) THEN 'null input' + WHEN udf('foo') THEN 'it was foo!' 
+ WHEN udf(vol(null)) THEN 'null input' WHEN 'bar' THEN 'it was bar!' END ) - WHEN 'it was foo!' THEN 'foo recognized' - WHEN 'it was bar!' THEN 'bar recognized' - ELSE 'unrecognized' END + WHEN udf('it was foo!') THEN 'foo recognized' + WHEN 'it was bar!' THEN udf('bar recognized') + ELSE 'unrecognized' END AS col -- !query 33 schema -struct<CASE WHEN (CASE WHEN (UDF:vol(bar) = foo) THEN it was foo! WHEN (UDF:vol(bar) = UDF:vol(null)) THEN null input WHEN (UDF:vol(bar) = bar) THEN it was bar! END = it was foo!) THEN foo recognized WHEN (CASE WHEN (UDF:vol(bar) = foo) THEN it was foo! WHEN (UDF:vol(bar) = UDF:vol(null)) THEN null input WHEN (UDF:vol(bar) = bar) THEN it was bar! END = it was bar!) THEN bar recognized ELSE unrecognized END:string> +struct<col:string> -- !query 33 output bar recognized ``` </p> </details> https://github.com/apache/spark/pull/25069 contains the same minor fixes as it's required to write the tests. ## How was this patch tested? Tested as guided in [SPARK-27921](https://issues.apache.org/jira/browse/SPARK-27921). Closes #25070 from HyukjinKwon/SPARK-28273. Authored-by: HyukjinKwon <gurwls223@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
437 lines
19 KiB
Python
437 lines
19 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
User-defined function related classes and functions
|
|
"""
|
|
import functools
|
|
import sys
|
|
|
|
from pyspark import SparkContext, since
|
|
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
|
|
from pyspark.sql.column import Column, _to_java_column, _to_seq
|
|
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
|
|
to_arrow_type, to_arrow_schema
|
|
from pyspark.util import _get_argspec
|
|
|
|
# Only UDFRegistration is exported by ``from pyspark.sql.udf import *``; the other
# names defined in this module are not part of the star-import surface.
__all__ = ["UDFRegistration"]
|
|
|
|
|
|
def _wrap_function(sc, func, returnType):
    """
    Serialize ``func`` together with ``returnType`` and wrap the result in a JVM
    ``PythonFunction``.

    :param sc: the active :class:`SparkContext`, used for the JVM gateway, the Python
        executable/version and the accumulator.
    :param func: the Python callable to serialize.
    :param returnType: the return type pickled alongside ``func``.
    :return: the JVM ``PythonFunction`` handle.
    """
    pickled, bvars, environment, include_paths = _prepare_for_python_RDD(
        sc, (func, returnType))
    return sc._jvm.PythonFunction(
        bytearray(pickled),
        environment,
        include_paths,
        sc.pythonExec,
        sc.pythonVer,
        bvars,
        sc._javaAccumulator,
    )
|
|
|
|
|
|
def _create_udf(f, returnType, evalType):
    """
    Build a :class:`UserDefinedFunction` for ``f`` and return its wrapped function form.

    For pandas-based eval types, the minimum PyArrow version is enforced and the
    signature of ``f`` is checked before the UDF object is created.

    :param f: the Python callable backing the UDF.
    :param returnType: a :class:`DataType` or DDL-formatted type string.
    :param evalType: a :class:`PythonEvalType` value.
    :return: the function returned by :meth:`UserDefinedFunction._wrapped`.
    :raises ValueError: if a scalar pandas UDF takes no arguments, or a grouped map
        pandas UDF does not take exactly one or two arguments.
    """
    pandas_eval_types = (
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
    )
    if evalType in pandas_eval_types:
        from pyspark.sql.utils import require_minimum_pyarrow_version
        require_minimum_pyarrow_version()

        spec = _get_argspec(f)
        arg_count = len(spec.args)

        is_scalar = evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
        # A scalar pandas UDF needs at least one input column; *args also counts.
        if is_scalar and arg_count == 0 and spec.varargs is None:
            raise ValueError(
                "Invalid function: 0-arg pandas_udfs are not supported. "
                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
            )

        is_grouped_map = evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
        if is_grouped_map and arg_count not in (1, 2):
            raise ValueError(
                "Invalid function: pandas_udfs with function type GROUPED_MAP "
                "must take either one argument (data) or two arguments (key, data).")

    # name=None lets UserDefinedFunction derive the name from f itself.
    udf_obj = UserDefinedFunction(
        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
    return udf_obj._wrapped()
|
|
|
|
|
|
class UserDefinedFunction(object):
    """
    User defined function in Python

    .. versionadded:: 1.3
    """
    def __init__(self, func,
                 returnType=StringType(),
                 name=None,
                 evalType=PythonEvalType.SQL_BATCHED_UDF,
                 deterministic=True):
        """
        :param func: the Python callable run by this UDF.
        :param returnType: a :class:`DataType` or a DDL-formatted type string. Strings are
            parsed lazily by :attr:`returnType`.
        :param name: the UDF name; when falsy, ``func.__name__`` is used, or the class name
            of ``func`` for callables without ``__name__``.
        :param evalType: an ``int`` :class:`PythonEvalType` value.
        :param deterministic: whether the UDF is deterministic.
        :raises TypeError: if ``func`` is not callable, ``returnType`` is neither a
            :class:`DataType` nor a ``str``, or ``evalType`` is not an ``int``.
        """
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
                "{0}".format(type(func)))

        if not isinstance(returnType, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str "
                "but is {}".format(returnType))

        if not isinstance(evalType, int):
            raise TypeError(
                "Invalid evalType: evalType should be an int but is {}".format(evalType))

        self.func = func
        # The return type exactly as given (DataType or DDL string); see ``returnType``.
        self._returnType = returnType
        # Parsed and validated return type, filled in on first successful access.
        self._returnType_placeholder = None
        # Stores UserDefinedPythonFunctions jobj, once initialized
        self._judf_placeholder = None
        # Callables such as functools.partial objects have no __name__, so fall back
        # to the class name.
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self):
        """
        The parsed and validated :class:`DataType` of this UDF.

        The result is cached only after validation succeeds. An unsupported type therefore
        raises on every access instead of being returned silently after the first failure.

        :raises NotImplementedError: if a pandas UDF's return type cannot be converted
            to Arrow.
        :raises TypeError: if a grouped map or map iterator pandas UDF does not declare
            a :class:`StructType`.
        """
        # This makes sure this is called after SparkContext is initialized.
        # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
        if self._returnType_placeholder is not None:
            return self._returnType_placeholder

        if isinstance(self._returnType, DataType):
            return_type = self._returnType
        else:
            return_type = _parse_datatype_string(self._returnType)

        if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF or \
                self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
            try:
                to_arrow_type(return_type)
            except TypeError:
                raise NotImplementedError(
                    "Invalid returnType with scalar Pandas UDFs: %s is "
                    "not supported" % str(return_type))
        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            if isinstance(return_type, StructType):
                try:
                    # The required StructType is validated as a whole Arrow schema.
                    to_arrow_schema(return_type)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid returnType with grouped map Pandas UDFs: "
                        "%s is not supported" % str(return_type))
            else:
                raise TypeError("Invalid returnType for grouped map Pandas "
                                "UDFs: returnType must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
            if isinstance(return_type, StructType):
                try:
                    # Same rule as grouped map: validate the StructType as an Arrow schema.
                    to_arrow_schema(return_type)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid returnType with map iterator Pandas UDFs: "
                        "%s is not supported" % str(return_type))
            else:
                raise TypeError("Invalid returnType for map iterator Pandas "
                                "UDFs: returnType must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
            try:
                # StructType is not yet allowed as a return type, explicitly check here to fail fast
                if isinstance(return_type, StructType):
                    raise TypeError
                to_arrow_type(return_type)
            except TypeError:
                raise NotImplementedError(
                    "Invalid returnType with grouped aggregate Pandas UDFs: "
                    "%s is not supported" % str(return_type))

        self._returnType_placeholder = return_type
        return return_type

    @property
    def _judf(self):
        """The JVM ``UserDefinedPythonFunction`` for this UDF, created on first access."""
        # It is possible that concurrent access, to newly created UDF,
        # will initialize multiple UserDefinedPythonFunctions.
        # This is unlikely, doesn't affect correctness,
        # and should have a minimal performance impact.
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder

    def _create_judf(self):
        """
        Build the JVM ``UserDefinedPythonFunction`` from the pickled function, its return
        type, name, eval type and determinism flag.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        # The Python DataType is handed to the JVM through its JSON representation.
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        """Apply this UDF to the given columns (or column names) and return a :class:`Column`."""
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

    # This function is for improving the online help system in the interactive interpreter.
    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
    # argument annotation. (See: SPARK-19161)
    def _wrapped(self):
        """
        Wrap this udf with a function and attach docstring from func
        """

        # It is possible for a callable instance without __name__ attribute or/and
        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
        # we should avoid wrapping the attributes from the wrapped function to the wrapper
        # function. So, we take out these attribute names from the default names to set and
        # then manually assign it after being wrapped.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args):
            return self(*args)

        wrapper.__name__ = self._name
        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                              else self.func.__class__.__module__)

        wrapper.func = self.func
        # Accessing returnType here parses and validates it eagerly.
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
        wrapper.deterministic = self.deterministic
        wrapper.asNondeterministic = functools.wraps(
            self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
        return wrapper

    def asNondeterministic(self):
        """
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 2.3
        """
        # Here, we explicitly clean the cache to create a JVM UDF instance
        # with 'deterministic' updated. See SPARK-23233.
        self._judf_placeholder = None
        self.deterministic = False
        return self
|
|
|
|
|
|
class UDFRegistration(object):
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession):
        self.sparkSession = sparkSession

    @ignore_unicode_prefix
    @since("1.3.1")
    def register(self, name, f, returnType=None):
        """Register a Python function (including lambda function) or a user-defined function
        as a SQL function.

        :param name: name of the user-defined function in SQL statements.
        :param f: a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        :param returnType: the return type of the registered user-defined function. The value can
            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        :return: a user-defined function.
        :raises TypeError: if `returnType` is given while `f` is a user-defined function.
        :raises ValueError: if `f` is a user-defined function whose eval type cannot be
            registered as a SQL function.

        To register a nondeterministic Python function, users need to first build
        a nondeterministic user-defined function for the Python function and then register it
        as a SQL function.

        `returnType` can be optionally specified when `f` is a Python function but not
        when `f` is a user-defined function. Please see below.

        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)=u'4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)=u'3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function:

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
            ... def add_one(x):
            ...     return x + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
            ... def sum_udf(v):
            ...     return v.sum()
            ...
            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
            >>> spark.sql(q).collect()  # doctest: +SKIP
            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]

            .. note:: Registration for a user-defined function (case 2.) was added from
                Spark 2.3.0.
        """

        # This is to check whether the input function is from a user-defined function or
        # Python function.
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                raise TypeError(
                    "Invalid returnType: data type can not be specified when f is "
                    "a user-defined function, but got %s." % returnType)
            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                                  PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
                raise ValueError(
                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
                    "SQL_SCALAR_PANDAS_ITER_UDF, SQL_GROUPED_AGG_PANDAS_UDF or "
                    "SQL_MAP_PANDAS_ITER_UDF.")
            # Re-create the UDF under the SQL name while keeping f's function, type,
            # eval type and determinism; the caller gets f itself back.
            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """Register a Java user-defined function as a SQL function.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not specified we would infer it via reflection.

        :param name: name of the user-defined function
        :param javaClassName: fully qualified name of java class
        :param returnType: the return type of the registered Java function. The value can be either
            a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        >>> spark.sql("SELECT javaStringLength('test')").collect()
        [Row(javaStringLength(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        >>> spark.sql("SELECT javaStringLength2('test')").collect()
        [Row(javaStringLength2(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
        >>> spark.sql("SELECT javaStringLength3('test')").collect()
        [Row(javaStringLength3(test)=4)]
        """

        # jdt stays None when no returnType is given, leaving the type to be inferred.
        jdt = None
        if returnType is not None:
            if not isinstance(returnType, DataType):
                returnType = _parse_datatype_string(returnType)
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaUDAF(self, name, javaClassName):
        """Register a Java user-defined aggregate function as a SQL function.

        :param name: name of the user-defined aggregate function
        :param javaClassName: fully qualified name of java class

        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
        >>> df.createOrReplaceTempView("df")
        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
        """

        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
|
|
|
|
|
|
def _test():
    """
    Run this module's doctests against a local SparkSession.

    The session is exposed to the doctests as ``spark``. The process exits with status
    -1 if any doctest fails.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf

    test_globals = dict(pyspark.sql.udf.__dict__)
    session = (SparkSession.builder
               .master("local[4]")
               .appName("sql.udf tests")
               .getOrCreate())
    test_globals['spark'] = session

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    failures, _ = doctest.testmod(pyspark.sql.udf, globs=test_globals, optionflags=flags)
    session.stop()

    if failures:
        sys.exit(-1)


if __name__ == "__main__":
    _test()
|