[SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark

## What changes were proposed in this pull request?

This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0.

These are inconsistent with the Scala / Java APIs, and they basically do the same thing as `spark.udf.register*`.

Also, this PR moves the logic from `[sqlContext|spark.catalog].register*` to `spark.udf.register*` and reuses the docstring.

This PR also handles minor doc corrections. It also includes https://github.com/apache/spark/pull/20158

## How was this patch tested?

Manually tested, manually checked the API documentation, and added tests to check that the deprecated APIs call their aliases correctly.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20288 from HyukjinKwon/deprecate-udf.
This commit is contained in:
hyukjinkwon 2018-01-18 14:51:05 +09:00 committed by Takuya UESHIN
parent 0219470206
commit 39d244d921
8 changed files with 232 additions and 208 deletions

View file

@ -400,6 +400,7 @@ pyspark_sql = Module(
"pyspark.sql.functions",
"pyspark.sql.readwriter",
"pyspark.sql.streaming",
"pyspark.sql.udf",
"pyspark.sql.window",
"pyspark.sql.tests",
]

View file

@ -224,92 +224,17 @@ class Catalog(object):
"""
self._jcatalog.dropGlobalTempView(viewName)
@ignore_unicode_prefix
@since(2.0)
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.
"""An alias for :func:`spark.udf.register`.
See :meth:`pyspark.sql.UDFRegistration.register`.
:func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to a string. The produced object must
match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
type of the given UDF as the return type of the registered UDF. The input parameter
`returnType` is None by default. If given by users, the value must be None.
:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`
>>> strlen = spark.catalog.registerFunction("stringLengthString", len)
>>> spark.sql("SELECT stringLengthString('test')").collect()
[Row(stringLengthString(test)=u'4')]
>>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
[Row(stringLengthString(text)=u'3')]
>>> from pyspark.sql.types import IntegerType
>>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = spark.udf.register("slen", slen)
>>> spark.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
>>> spark.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
"""
# This is to check whether the input function is a wrapped/native UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
if returnType is not None:
raise TypeError(
"Invalid returnType: None is expected when f is a UserDefinedFunction, "
"but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)
return_udf = f
else:
if returnType is None:
returnType = StringType()
register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
evalType=PythonEvalType.SQL_BATCHED_UDF)
return_udf = register_udf._wrapped()
self._jsparkSession.udf().registerPython(name, register_udf._judf)
return return_udf
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.register instead.",
DeprecationWarning)
return self._sparkSession.udf.register(name, f, returnType)
@since(2.0)
def isCached(self, tableName):

View file

@ -29,9 +29,10 @@ from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.utils import install_exception_handler
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
__all__ = ["SQLContext", "HiveContext"]
class SQLContext(object):
@ -147,7 +148,7 @@ class SQLContext(object):
:return: :class:`UDFRegistration`
"""
return UDFRegistration(self)
return self.sparkSession.udf
@since(1.4)
def range(self, start, end=None, step=1, numPartitions=None):
@ -172,113 +173,29 @@ class SQLContext(object):
"""
return self.sparkSession.range(start, end, step, numPartitions)
@ignore_unicode_prefix
@since(1.2)
def registerFunction(self, name, f, returnType=None):
"""Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
as a UDF. The registered UDF can be used in SQL statements.
"""An alias for :func:`spark.udf.register`.
See :meth:`pyspark.sql.UDFRegistration.register`.
:func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
In addition to a name and the function itself, `returnType` can be optionally specified.
1) When f is a Python function, `returnType` defaults to a string. The produced object must
match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
type of the given UDF as the return type of the registered UDF. The input parameter
`returnType` is None by default. If given by users, the value must be None.
:param name: name of the UDF in SQL statements.
:param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
row-at-a-time or vectorized.
:param returnType: the return type of the registered UDF.
:return: a wrapped/native :class:`UserDefinedFunction`
>>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(stringLengthString(test)=u'4')]
>>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
[Row(stringLengthString(text)=u'3')]
>>> from pyspark.sql.types import IntegerType
>>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
>>> from pyspark.sql.types import IntegerType
>>> from pyspark.sql.functions import udf
>>> slen = udf(lambda s: len(s), IntegerType())
>>> _ = sqlContext.udf.register("slen", slen)
>>> sqlContext.sql("SELECT slen('test')").collect()
[Row(slen(test)=4)]
>>> import random
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
>>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
>>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
>>> sqlContext.range(1).select(new_random_udf()).collect() # doctest: +SKIP
[Row(<lambda>()=26)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
... def add_one(x):
... return x + 1
...
>>> _ = sqlContext.udf.register("add_one", add_one) # doctest: +SKIP
>>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
"""
return self.sparkSession.catalog.registerFunction(name, f, returnType)
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.register instead.",
DeprecationWarning)
return self.sparkSession.udf.register(name, f, returnType)
@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None):
"""Register a java UDF so it can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not specified we would infer it via reflection.
:param name: name of the UDF
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object
>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF:javaStringLength(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF:javaStringLength2(test)=4)]
"""An alias for :func:`spark.udf.registerJavaFunction`.
See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
.. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
"""
jdt = None
if returnType is not None:
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
@ignore_unicode_prefix
@since(2.3)
def registerJavaUDAF(self, name, javaClassName):
"""Register a java UDAF so it can be used in SQL statements.
:param name: name of the UDAF
:param javaClassName: fully qualified name of java class
>>> sqlContext.registerJavaUDAF("javaUDAF",
... "test.org.apache.spark.sql.MyDoubleAvg")
>>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
>>> df.registerTempTable("df")
>>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""
self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
warnings.warn(
"Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
DeprecationWarning)
return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
@ -590,24 +507,6 @@ class HiveContext(SQLContext):
self._ssql_ctx.refreshTable(tableName)
class UDFRegistration(object):
"""Wrapper for user-defined function registration."""
def __init__(self, sqlContext):
self.sqlContext = sqlContext
def register(self, name, f, returnType=None):
return self.sqlContext.registerFunction(name, f, returnType)
def registerJavaFunction(self, name, javaClassName, returnType=None):
self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
def registerJavaUDAF(self, name, javaClassName):
self.sqlContext.registerJavaUDAF(name, javaClassName)
register.__doc__ = SQLContext.registerFunction.__doc__
def _test():
import os
import doctest

View file

@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()):
>>> import random
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
.. note:: The user-defined functions do not support conditional expressions or short curcuiting
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
... return pd.Series(np.random.randn(len(v))
>>> random = random.asNondeterministic() # doctest: +SKIP
.. note:: The user-defined functions do not support conditional expressions or short curcuiting
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
can fail on special rows, the workaround is to incorporate the condition into the functions.
"""

View file

@ -212,7 +212,8 @@ class GroupedData(object):
This function does not support partial aggregation, and requires shuffling all the data in
the :class:`DataFrame`.
:param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
:param udf: a group map user-defined function returned by
:meth:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(

View file

@ -29,7 +29,6 @@ else:
from pyspark import since
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.sql.catalog import Catalog
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
@ -280,6 +279,7 @@ class SparkSession(object):
:return: :class:`Catalog`
"""
from pyspark.sql.catalog import Catalog
if not hasattr(self, "_catalog"):
self._catalog = Catalog(self)
return self._catalog
@ -291,8 +291,8 @@ class SparkSession(object):
:return: :class:`UDFRegistration`
"""
from pyspark.sql.context import UDFRegistration
return UDFRegistration(self._wrapped)
from pyspark.sql.udf import UDFRegistration
return UDFRegistration(self)
@since(2.0)
def range(self, start, end=None, step=1, numPartitions=None):

View file

@ -372,6 +372,12 @@ class SQLTests(ReusedSQLTestCase):
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)
# This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
sqlContext = self.spark._wrapped
sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
[row] = sqlContext.sql("SELECT oneArg('test')").collect()
self.assertEqual(row[0], 4)
def test_udf2(self):
self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
@ -577,11 +583,25 @@ class SQLTests(ReusedSQLTestCase):
df.select(add_three("id").alias("plus_three")).collect()
)
# This is to check if a 'SQLContext.udf' can call its alias.
sqlContext = self.spark._wrapped
add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
self.assertListEqual(
df.selectExpr("add_four(id) AS plus_four").collect(),
df.select(add_four("id").alias("plus_four")).collect()
)
def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
sqlContext = spark._wrapped
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
def test_non_existed_udaf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",

View file

@ -19,11 +19,13 @@ User-defined function related classes and functions
"""
import functools
from pyspark import SparkContext
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
__all__ = ["UDFRegistration"]
def _wrap_function(sc, func, returnType):
command = (func, returnType)
@ -181,3 +183,179 @@ class UserDefinedFunction(object):
"""
self.deterministic = False
return self
class UDFRegistration(object):
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession):
        # The session whose JVM-side UDF registry every register* call writes to.
        self.sparkSession = sparkSession

    @ignore_unicode_prefix
    @since("1.3.1")
    def register(self, name, f, returnType=None):
        """Registers a Python function (including lambda function) or a user-defined function
        so that it can be used in SQL statements.

        :param name: name of the user-defined function in SQL statements.
        :param f: a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        :param returnType: the return type of the registered user-defined function.
        :return: a user-defined function.
        :raises TypeError: if `returnType` is given while `f` is already a user-defined
            function.
        :raises ValueError: if `f` is a user-defined function whose evaluation type is
            neither ``SQL_BATCHED_UDF`` nor ``SQL_PANDAS_SCALAR_UDF``.

        `returnType` can be optionally specified when `f` is a Python function but not
        when `f` is a user-defined function. Please see below.

        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)=u'4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)=u'3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function:

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
            ... def add_one(x):
            ...     return x + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            .. note:: Registration for a user-defined function (case 2.) was added from
                Spark 2.3.0.
        """

        # This is to check whether the input function is from a user-defined function or
        # Python function. The presence of 'asNondeterministic' is used as the marker.
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                # Note the trailing space before the literal is continued: without it the
                # two adjacent literals are concatenated into "...when f isa user-defined...".
                raise TypeError(
                    "Invalid returnType: data type can not be specified when f is "
                    "a user-defined function, but got %s." % returnType)
            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
                raise ValueError(
                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
            # Rebuild the UDF under the SQL-visible name while keeping the caller's
            # function, return type, evaluation type and determinism flag.
            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            # The caller gets back the very object that was passed in.
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """Register a Java user-defined function so it can be used in SQL statements.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not specified we would infer it via reflection.

        :param name: name of the user-defined function
        :param javaClassName: fully qualified name of java class
        :param returnType: a :class:`pyspark.sql.types.DataType` object

        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        >>> spark.sql("SELECT javaStringLength('test')").collect()
        [Row(UDF:javaStringLength(test)=4)]
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        >>> spark.sql("SELECT javaStringLength2('test')").collect()
        [Row(UDF:javaStringLength2(test)=4)]
        """
        # A None data type is forwarded as-is when no return type was supplied.
        jdt = None
        if returnType is not None:
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaUDAF(self, name, javaClassName):
        """Register a Java user-defined aggregate function so it can be used in SQL statements.

        :param name: name of the user-defined aggregate function
        :param javaClassName: fully qualified name of java class

        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
        >>> df.registerTempTable("df")
        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
        """
        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
def _test():
    """Run this module's doctests against a local SparkSession.

    A ``local[4]`` session is created and exposed to the doctests as ``spark``.
    The session is always stopped, even if the doctest run itself raises, and
    the process exits with status -1 when any doctest fails.
    """
    import doctest
    import sys
    from pyspark.sql import SparkSession
    import pyspark.sql.udf
    globs = pyspark.sql.udf.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.udf tests")\
        .getOrCreate()
    globs['spark'] = spark
    try:
        (failure_count, test_count) = doctest.testmod(
            pyspark.sql.udf, globs=globs,
            optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    finally:
        # Stop the session on every path so a crashing doctest run does not leak it.
        spark.stop()
    if failure_count:
        # sys.exit is used instead of the bare `exit`, which is only injected by
        # the `site` module and is not guaranteed to be available.
        sys.exit(-1)


if __name__ == "__main__":
    _test()