diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 4f4ae10892..d261720314 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -20,6 +20,7 @@ A collections of builtin functions
 """
 import math
 import sys
+import functools
 
 if sys.version < "3":
     from itertools import imap as map
@@ -1908,22 +1909,48 @@ class UserDefinedFunction(object):
 
 @since(1.3)
-def udf(f, returnType=StringType()):
+def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).
 
     .. note:: The user-defined functions must be deterministic. Due to optimization,
         duplicate invocations may be eliminated or the function may even be invoked more times than
         it is present in the query.
 
-    :param f: python function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
+    :param f: python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.DataType` object
 
     >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
-    >>> df.select(slen(df.name).alias('slen')).collect()
-    [Row(slen=5), Row(slen=3)]
+    >>> @udf
+    ... def to_upper(s):
+    ...     if s is not None:
+    ...         return s.upper()
+    ...
+    >>> @udf(returnType=IntegerType())
+    ... def add_one(x):
+    ...     if x is not None:
+    ...         return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
     """
-    return UserDefinedFunction(f, returnType)
+    def _udf(f, returnType=StringType()):
+        return UserDefinedFunction(f, returnType)
+
+    # decorator @udf, @udf() or @udf(dataType())
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(_udf, returnType=return_type)
+    else:
+        return _udf(f=f, returnType=returnType)
+
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
 __all__ = [k for k, v in globals().items()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 62e1a8c363..d8b7b3137c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -514,6 +514,63 @@ class SQLTests(ReusedPySparkTestCase):
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
+    def test_udf_with_decorator(self):
+        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.types import IntegerType, DoubleType
+
+        @udf(IntegerType())
+        def add_one(x):
+            if x is not None:
+                return x + 1
+
+        @udf(returnType=DoubleType())
+        def add_two(x):
+            if x is not None:
+                return float(x + 2)
+
+        @udf
+        def to_upper(x):
+            if x is not None:
+                return x.upper()
+
+        @udf()
+        def to_lower(x):
+            if x is not None:
+                return x.lower()
+
+        @udf
+        def substr(x, start, end):
+            if x is not None:
+                return x[start:end]
+
+        @udf("long")
+        def trunc(x):
+            return int(x)
+
+        @udf(returnType="double")
+        def as_double(x):
+            return float(x)
+
+        df = (
+            self.spark
+                .createDataFrame(
+                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+                .select(
+                    add_one("one"), add_two("one"),
+                    to_upper("Foo"), to_lower("Foo"),
+                    substr("foobar", lit(0), lit(3)),
+                    trunc("float"),
+                    as_double("one")))
+
+        self.assertListEqual(
+            [tpe for _, tpe in df.dtypes],
+            ["int", "double", "string", "string", "string", "bigint", "double"]
+        )
+
+        self.assertListEqual(
+            list(df.first()),
+            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+        )
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)