[SPARK-19160][PYTHON][SQL] Add udf decorator
## What changes were proposed in this pull request? This PR adds `udf` decorator syntax as proposed in [SPARK-19160](https://issues.apache.org/jira/browse/SPARK-19160). This allows users to define a UDF using simplified syntax: ```python from pyspark.sql.functions import udf @udf(IntegerType()) def add_one(x): """Adds one""" if x is not None: return x + 1 ``` without the need to define a separate function and then wrap it in `udf`. ## How was this patch tested? Existing unit tests to ensure backward compatibility and additional unit tests covering new functionality. Author: zero323 <zero323@users.noreply.github.com> Closes #16533 from zero323/SPARK-19160.
This commit is contained in:
parent
6eca21ba88
commit
c97f4e17de
|
@ -20,6 +20,7 @@ A collections of builtin functions
|
|||
"""
|
||||
import math
|
||||
import sys
|
||||
import functools
|
||||
|
||||
if sys.version < "3":
|
||||
from itertools import imap as map
|
||||
|
@ -1908,22 +1909,48 @@ class UserDefinedFunction(object):
|
|||
|
||||
|
||||
@since(1.3)
|
||||
def udf(f, returnType=StringType()):
|
||||
def udf(f=None, returnType=StringType()):
|
||||
"""Creates a :class:`Column` expression representing a user defined function (UDF).
|
||||
|
||||
.. note:: The user-defined functions must be deterministic. Due to optimization,
|
||||
duplicate invocations may be eliminated or the function may even be invoked more times than
|
||||
it is present in the query.
|
||||
|
||||
:param f: python function
|
||||
:param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
|
||||
:param f: python function if used as a standalone function
|
||||
:param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
|
||||
|
||||
>>> from pyspark.sql.types import IntegerType
|
||||
>>> slen = udf(lambda s: len(s), IntegerType())
|
||||
>>> df.select(slen(df.name).alias('slen')).collect()
|
||||
[Row(slen=5), Row(slen=3)]
|
||||
>>> @udf
|
||||
... def to_upper(s):
|
||||
... if s is not None:
|
||||
... return s.upper()
|
||||
...
|
||||
>>> @udf(returnType=IntegerType())
|
||||
... def add_one(x):
|
||||
... if x is not None:
|
||||
... return x + 1
|
||||
...
|
||||
>>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
|
||||
>>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
|
||||
+----------+--------------+------------+
|
||||
|slen(name)|to_upper(name)|add_one(age)|
|
||||
+----------+--------------+------------+
|
||||
| 8| JOHN DOE| 22|
|
||||
+----------+--------------+------------+
|
||||
"""
|
||||
return UserDefinedFunction(f, returnType)
|
||||
def _udf(f, returnType=StringType()):
|
||||
return UserDefinedFunction(f, returnType)
|
||||
|
||||
# decorator @udf, @udf() or @udf(dataType())
|
||||
if f is None or isinstance(f, (str, DataType)):
|
||||
# If DataType has been passed as a positional argument
|
||||
# for decorator use it as a returnType
|
||||
return_type = f or returnType
|
||||
return functools.partial(_udf, returnType=return_type)
|
||||
else:
|
||||
return _udf(f=f, returnType=returnType)
|
||||
|
||||
|
||||
blacklist = ['map', 'since', 'ignore_unicode_prefix']
|
||||
__all__ = [k for k, v in globals().items()
|
||||
|
|
|
@ -514,6 +514,63 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
non_callable = None
|
||||
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
|
||||
|
||||
def test_udf_with_decorator(self):
|
||||
from pyspark.sql.functions import lit, udf
|
||||
from pyspark.sql.types import IntegerType, DoubleType
|
||||
|
||||
@udf(IntegerType())
|
||||
def add_one(x):
|
||||
if x is not None:
|
||||
return x + 1
|
||||
|
||||
@udf(returnType=DoubleType())
|
||||
def add_two(x):
|
||||
if x is not None:
|
||||
return float(x + 2)
|
||||
|
||||
@udf
|
||||
def to_upper(x):
|
||||
if x is not None:
|
||||
return x.upper()
|
||||
|
||||
@udf()
|
||||
def to_lower(x):
|
||||
if x is not None:
|
||||
return x.lower()
|
||||
|
||||
@udf
|
||||
def substr(x, start, end):
|
||||
if x is not None:
|
||||
return x[start:end]
|
||||
|
||||
@udf("long")
|
||||
def trunc(x):
|
||||
return int(x)
|
||||
|
||||
@udf(returnType="double")
|
||||
def as_double(x):
|
||||
return float(x)
|
||||
|
||||
df = (
|
||||
self.spark
|
||||
.createDataFrame(
|
||||
[(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
|
||||
.select(
|
||||
add_one("one"), add_two("one"),
|
||||
to_upper("Foo"), to_lower("Foo"),
|
||||
substr("foobar", lit(0), lit(3)),
|
||||
trunc("float"), as_double("one")))
|
||||
|
||||
self.assertListEqual(
|
||||
[tpe for _, tpe in df.dtypes],
|
||||
["int", "double", "string", "string", "string", "bigint", "double"]
|
||||
)
|
||||
|
||||
self.assertListEqual(
|
||||
list(df.first()),
|
||||
[2, 3.0, "FOO", "foo", "foo", 3, 1.0]
|
||||
)
|
||||
|
||||
def test_basic_functions(self):
|
||||
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
||||
df = self.spark.read.json(rdd)
|
||||
|
|
Loading…
Reference in a new issue