[SPARK-23122][PYTHON][SQL] Deprecate register* for UDFs in SQLContext and Catalog in PySpark
## What changes were proposed in this pull request?

This PR proposes to deprecate `register*` for UDFs in `SQLContext` and `Catalog` in Spark 2.3.0. These APIs are inconsistent with the Scala / Java APIs, and they do essentially the same thing as `spark.udf.register*`. This PR also moves the logic from `[sqlContext|spark.catalog].register*` into `spark.udf.register*` and reuses the docstrings there, and makes minor doc corrections along the way. It also includes https://github.com/apache/spark/pull/20158.

## How was this patch tested?

Manually tested, manually checked the API documentation, and added tests that check the deprecated APIs call their aliases correctly.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20288 from HyukjinKwon/deprecate-udf.
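For reference, a minimal sketch of the migration this change asks for (the session setup and the example function are illustrative, not part of the patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-migration").getOrCreate()

# Deprecated as of 2.3.0: still works, but emits a DeprecationWarning and
# delegates to spark.udf.register under the hood.
spark.catalog.registerFunction("strlen_old", lambda s: len(s), IntegerType())

# Preferred going forward: register through spark.udf.
strlen = spark.udf.register("strlen", lambda s: len(s), IntegerType())
spark.sql("SELECT strlen('test')").collect()  # [Row(strlen(test)=4)]
```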
This commit is contained in:
parent 0219470206
commit 39d244d921
dev/sparktestsupport/modules.py
@@ -400,6 +400,7 @@ pyspark_sql = Module(
         "pyspark.sql.functions",
         "pyspark.sql.readwriter",
         "pyspark.sql.streaming",
+        "pyspark.sql.udf",
         "pyspark.sql.window",
         "pyspark.sql.tests",
     ]
python/pyspark/sql/catalog.py
@@ -224,92 +224,17 @@ class Catalog(object):
         """
         self._jcatalog.dropGlobalTempView(viewName)

     @ignore_unicode_prefix
     @since(2.0)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`spark.catalog.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = spark.catalog.registerFunction("stringLengthString", len)
-        >>> spark.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.catalog.registerFunction("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = spark.udf.register("stringLengthInt", len, IntegerType())
-        >>> spark.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = spark.udf.register("slen", slen)
-        >>> spark.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = spark.catalog.registerFunction("random_udf", random_udf)
-        >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> spark.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
         """
-
-        # This is to check whether the input function is a wrapped/native UserDefinedFunction
-        if hasattr(f, 'asNondeterministic'):
-            if returnType is not None:
-                raise TypeError(
-                    "Invalid returnType: None is expected when f is a UserDefinedFunction, "
-                    "but got %s." % returnType)
-            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
-                raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
-            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
-                                               evalType=f.evalType,
-                                               deterministic=f.deterministic)
-            return_udf = f
-        else:
-            if returnType is None:
-                returnType = StringType()
-            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
-                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
-            return_udf = register_udf._wrapped()
-        self._jsparkSession.udf().registerPython(name, register_udf._judf)
-        return return_udf
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self._sparkSession.udf.register(name, f, returnType)

     @since(2.0)
     def isCached(self, tableName):
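To illustrate the post-patch behavior of this hunk, the deprecated `Catalog.registerFunction` still works but now warns before delegating; a small sketch (session setup and names are illustrative):

```python
import warnings

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("deprecation-check").getOrCreate()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Emits a DeprecationWarning and forwards to spark.udf.register.
    spark.catalog.registerFunction("strlen", lambda s: len(s), IntegerType())

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(spark.sql("SELECT strlen('ab')").collect())  # [Row(strlen(ab)=2)]
```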
python/pyspark/sql/context.py
@@ -29,9 +29,10 @@ from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
 from pyspark.sql.types import IntegerType, Row, StringType
+from pyspark.sql.udf import UDFRegistration
 from pyspark.sql.utils import install_exception_handler

-__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
+__all__ = ["SQLContext", "HiveContext"]


 class SQLContext(object):
@@ -147,7 +148,7 @@ class SQLContext(object):

         :return: :class:`UDFRegistration`
         """
-        return UDFRegistration(self)
+        return self.sparkSession.udf

     @since(1.4)
     def range(self, start, end=None, step=1, numPartitions=None):
@@ -172,113 +173,29 @@ class SQLContext(object):
         """
         return self.sparkSession.range(start, end, step, numPartitions)

     @ignore_unicode_prefix
     @since(1.2)
     def registerFunction(self, name, f, returnType=None):
-        """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction`
-        as a UDF. The registered UDF can be used in SQL statements.
-
-        :func:`spark.udf.register` is an alias for :func:`sqlContext.registerFunction`.
-
-        In addition to a name and the function itself, `returnType` can be optionally specified.
-        1) When f is a Python function, `returnType` defaults to a string. The produced object must
-        match the specified type. 2) When f is a :class:`UserDefinedFunction`, Spark uses the return
-        type of the given UDF as the return type of the registered UDF. The input parameter
-        `returnType` is None by default. If given by users, the value must be None.
-
-        :param name: name of the UDF in SQL statements.
-        :param f: a Python function, or a wrapped/native UserDefinedFunction. The UDF can be either
-            row-at-a-time or vectorized.
-        :param returnType: the return type of the registered UDF.
-        :return: a wrapped/native :class:`UserDefinedFunction`
-
-        >>> strlen = sqlContext.registerFunction("stringLengthString", lambda x: len(x))
-        >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
-        [Row(stringLengthString(test)=u'4')]
-
-        >>> sqlContext.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
-        [Row(stringLengthString(text)=u'3')]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
-        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
-        [Row(stringLengthInt(test)=4)]
-
-        >>> from pyspark.sql.types import IntegerType
-        >>> from pyspark.sql.functions import udf
-        >>> slen = udf(lambda s: len(s), IntegerType())
-        >>> _ = sqlContext.udf.register("slen", slen)
-        >>> sqlContext.sql("SELECT slen('test')").collect()
-        [Row(slen(test)=4)]
-
-        >>> import random
-        >>> from pyspark.sql.functions import udf
-        >>> from pyspark.sql.types import IntegerType
-        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
-        >>> new_random_udf = sqlContext.registerFunction("random_udf", random_udf)
-        >>> sqlContext.sql("SELECT random_udf()").collect()  # doctest: +SKIP
-        [Row(random_udf()=82)]
-        >>> sqlContext.range(1).select(new_random_udf()).collect()  # doctest: +SKIP
-        [Row(<lambda>()=26)]
-
-        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
-        >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
-        ... def add_one(x):
-        ...     return x + 1
-        ...
-        >>> _ = sqlContext.udf.register("add_one", add_one)  # doctest: +SKIP
-        >>> sqlContext.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
-        [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+        """An alias for :func:`spark.udf.register`.
+        See :meth:`pyspark.sql.UDFRegistration.register`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
         """
-        return self.sparkSession.catalog.registerFunction(name, f, returnType)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.register instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.register(name, f, returnType)

     @ignore_unicode_prefix
     @since(2.1)
     def registerJavaFunction(self, name, javaClassName, returnType=None):
-        """Register a java UDF so it can be used in SQL statements.
-
-        In addition to a name and the function itself, the return type can be optionally specified.
-        When the return type is not specified we would infer it via reflection.
-        :param name: name of the UDF
-        :param javaClassName: fully qualified name of java class
-        :param returnType: a :class:`pyspark.sql.types.DataType` object
-
-        >>> sqlContext.registerJavaFunction("javaStringLength",
-        ...   "test.org.apache.spark.sql.JavaStringLength", IntegerType())
-        >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
-        [Row(UDF:javaStringLength(test)=4)]
-        >>> sqlContext.registerJavaFunction("javaStringLength2",
-        ...   "test.org.apache.spark.sql.JavaStringLength")
-        >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
-        [Row(UDF:javaStringLength2(test)=4)]
+        """An alias for :func:`spark.udf.registerJavaFunction`.
+        See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.
+
+        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
         """
-        jdt = None
-        if returnType is not None:
-            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
-        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
-
-    @ignore_unicode_prefix
-    @since(2.3)
-    def registerJavaUDAF(self, name, javaClassName):
-        """Register a java UDAF so it can be used in SQL statements.
-
-        :param name: name of the UDAF
-        :param javaClassName: fully qualified name of java class
-
-        >>> sqlContext.registerJavaUDAF("javaUDAF",
-        ...   "test.org.apache.spark.sql.MyDoubleAvg")
-        >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
-        >>> df.registerTempTable("df")
-        >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
-        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
-        """
-        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+        warnings.warn(
+            "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
+            DeprecationWarning)
+        return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)

     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
@@ -590,24 +507,6 @@ class HiveContext(SQLContext):
         self._ssql_ctx.refreshTable(tableName)


-class UDFRegistration(object):
-    """Wrapper for user-defined function registration."""
-
-    def __init__(self, sqlContext):
-        self.sqlContext = sqlContext
-
-    def register(self, name, f, returnType=None):
-        return self.sqlContext.registerFunction(name, f, returnType)
-
-    def registerJavaFunction(self, name, javaClassName, returnType=None):
-        self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
-
-    def registerJavaUDAF(self, name, javaClassName):
-        self.sqlContext.registerJavaUDAF(name, javaClassName)
-
-    register.__doc__ = SQLContext.registerFunction.__doc__
-
-
 def _test():
     import os
     import doctest
python/pyspark/sql/functions.py
@@ -2103,7 +2103,7 @@ def udf(f=None, returnType=StringType()):
     >>> import random
     >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()

-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
         in boolean expressions and it ends up with being executed all internally. If the functions
         can fail on special rows, the workaround is to incorporate the condition into the functions.
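The note fixed above also names the workaround: move the condition into the UDF itself rather than relying on boolean short-circuiting in SQL. A sketch under that reading (session setup, data, and names are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-guard").getOrCreate()

# Instead of `WHERE s IS NOT NULL AND strlen(s) > 1`, guard inside the function,
# since both sides of the AND may be evaluated.
safe_strlen = udf(lambda s: len(s) if s is not None else -1, IntegerType())

df = spark.createDataFrame([("ab",), (None,)], ["s"])
df.where(safe_strlen("s") > 1).show()
```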
@@ -2231,7 +2231,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     ...     return pd.Series(np.random.randn(len(v))
     >>> random = random.asNondeterministic()  # doctest: +SKIP

-    .. note:: The user-defined functions do not support conditional expressions or short curcuiting
+    .. note:: The user-defined functions do not support conditional expressions or short circuiting
        in boolean expressions and it ends up with being executed all internally. If the functions
        can fail on special rows, the workaround is to incorporate the condition into the functions.
    """
python/pyspark/sql/group.py
@@ -212,7 +212,8 @@ class GroupedData(object):
         This function does not support partial aggregation, and requires shuffling all the data in
         the :class:`DataFrame`.

-        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+        :param udf: a group map user-defined function returned by
+            :meth:`pyspark.sql.functions.pandas_udf`.

         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df = spark.createDataFrame(
python/pyspark/sql/session.py
@@ -29,7 +29,6 @@ else:

 from pyspark import since
 from pyspark.rdd import RDD, ignore_unicode_prefix
-from pyspark.sql.catalog import Catalog
 from pyspark.sql.conf import RuntimeConfig
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
@@ -280,6 +279,7 @@ class SparkSession(object):

         :return: :class:`Catalog`
         """
+        from pyspark.sql.catalog import Catalog
         if not hasattr(self, "_catalog"):
             self._catalog = Catalog(self)
         return self._catalog
@@ -291,8 +291,8 @@ class SparkSession(object):

         :return: :class:`UDFRegistration`
         """
-        from pyspark.sql.context import UDFRegistration
-        return UDFRegistration(self._wrapped)
+        from pyspark.sql.udf import UDFRegistration
+        return UDFRegistration(self)

     @since(2.0)
     def range(self, start, end=None, step=1, numPartitions=None):
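Since `SQLContext.udf` now evaluates to `self.sparkSession.udf`, registrations made through either entry point land in the same session. A sketch mirroring the tests below (`_wrapped` is the internal SQLContext handle those tests use; session setup is illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("shared-registry").getOrCreate()
sqlContext = spark._wrapped  # the internal SQLContext wrapping this session

# Registered via the SQLContext alias, visible through the session as well.
sqlContext.udf.register("plus_one", lambda x: int(x) + 1)
print(spark.sql("SELECT plus_one(1)").collect())  # default return type is string: [Row(plus_one(1)=u'2')]
```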
python/pyspark/sql/tests.py
@@ -372,6 +372,12 @@ class SQLTests(ReusedSQLTestCase):
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)

+        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
+        sqlContext = self.spark._wrapped
+        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+        self.assertEqual(row[0], 4)
+
     def test_udf2(self):
         self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
         self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
@@ -577,11 +583,25 @@ class SQLTests(ReusedSQLTestCase):
             df.select(add_three("id").alias("plus_three")).collect()
         )

+        # This is to check if a 'SQLContext.udf' can call its alias.
+        sqlContext = self.spark._wrapped
+        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_four(id) AS plus_four").collect(),
+            df.select(add_four("id").alias("plus_four")).collect()
+        )
+
     def test_non_existed_udf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                 lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

+        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
+        sqlContext = spark._wrapped
+        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))
+
     def test_non_existed_udaf(self):
         spark = self.spark
         self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
python/pyspark/sql/udf.py
@@ -19,11 +19,13 @@ User-defined function related classes and functions
 """
 import functools

-from pyspark import SparkContext
-from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
+from pyspark import SparkContext, since
+from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string

+__all__ = ["UDFRegistration"]
+

 def _wrap_function(sc, func, returnType):
     command = (func, returnType)
@@ -181,3 +183,179 @@ class UserDefinedFunction(object):
         """
         self.deterministic = False
         return self
+
+
+class UDFRegistration(object):
+    """
+    Wrapper for user-defined function registration. This instance can be accessed by
+    :attr:`spark.udf` or :attr:`sqlContext.udf`.
+
+    .. versionadded:: 1.3.1
+    """
+
+    def __init__(self, sparkSession):
+        self.sparkSession = sparkSession
+
+    @ignore_unicode_prefix
+    @since("1.3.1")
+    def register(self, name, f, returnType=None):
+        """Registers a Python function (including lambda function) or a user-defined function
+        in SQL statements.
+
+        :param name: name of the user-defined function in SQL statements.
+        :param f: a Python function, or a user-defined function. The user-defined function can
+            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
+            :meth:`pyspark.sql.functions.pandas_udf`.
+        :param returnType: the return type of the registered user-defined function.
+        :return: a user-defined function.
+
+        `returnType` can be optionally specified when `f` is a Python function but not
+        when `f` is a user-defined function. Please see below.
+
+        1. When `f` is a Python function:
+
+            `returnType` defaults to string type and can be optionally specified. The produced
+            object must match the specified type. In this case, this API works as if
+            `register(name, f, returnType=StringType())`.
+
+            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
+            >>> spark.sql("SELECT stringLengthString('test')").collect()
+            [Row(stringLengthString(test)=u'4')]
+
+            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+            [Row(stringLengthString(text)=u'3')]
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+            >>> spark.sql("SELECT stringLengthInt('test')").collect()
+            [Row(stringLengthInt(test)=4)]
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+            >>> spark.sql("SELECT stringLengthInt('test')").collect()
+            [Row(stringLengthInt(test)=4)]
+
+        2. When `f` is a user-defined function:
+
+            Spark uses the return type of the given user-defined function as the return type of
+            the registered user-defined function. `returnType` should not be specified.
+            In this case, this API works as if `register(name, f)`.
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> from pyspark.sql.functions import udf
+            >>> slen = udf(lambda s: len(s), IntegerType())
+            >>> _ = spark.udf.register("slen", slen)
+            >>> spark.sql("SELECT slen('test')").collect()
+            [Row(slen(test)=4)]
+
+            >>> import random
+            >>> from pyspark.sql.functions import udf
+            >>> from pyspark.sql.types import IntegerType
+            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
+            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+            [Row(random_udf()=82)]
+
+            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+            ... def add_one(x):
+            ...     return x + 1
+            ...
+            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
+            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+
+        .. note:: Registration for a user-defined function (case 2.) was added from
+            Spark 2.3.0.
+        """
+
+        # This is to check whether the input function is from a user-defined function or
+        # Python function.
+        if hasattr(f, 'asNondeterministic'):
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid returnType: data type can not be specified when f is"
+                    "a user-defined function, but got %s." % returnType)
+            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+                raise ValueError(
+                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
+            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
+                                               evalType=f.evalType,
+                                               deterministic=f.deterministic)
+            return_udf = f
+        else:
+            if returnType is None:
+                returnType = StringType()
+            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
+            return_udf = register_udf._wrapped()
+        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
+        return return_udf
+
+    @ignore_unicode_prefix
+    @since(2.3)
+    def registerJavaFunction(self, name, javaClassName, returnType=None):
+        """Register a Java user-defined function so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not specified we would infer it via reflection.
+
+        :param name: name of the user-defined function
+        :param javaClassName: fully qualified name of java class
+        :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> spark.udf.registerJavaFunction(
+        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        >>> spark.sql("SELECT javaStringLength('test')").collect()
+        [Row(UDF:javaStringLength(test)=4)]
+        >>> spark.udf.registerJavaFunction(
+        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
+        >>> spark.sql("SELECT javaStringLength2('test')").collect()
+        [Row(UDF:javaStringLength2(test)=4)]
+        """
+
+        jdt = None
+        if returnType is not None:
+            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
+        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+
+    @ignore_unicode_prefix
+    @since(2.3)
+    def registerJavaUDAF(self, name, javaClassName):
+        """Register a Java user-defined aggregate function so it can be used in SQL statements.
+
+        :param name: name of the user-defined aggregate function
+        :param javaClassName: fully qualified name of java class
+
+        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
+        >>> df.registerTempTable("df")
+        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
+        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
+        """
+
+        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.udf
+    globs = pyspark.sql.udf.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.udf tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.udf, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+    spark.stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
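Condensed, the `register` dispatch added above has two cases: a plain Python function may carry an explicit `returnType`, while an existing user-defined function must not. A sketch of both branches (session setup and names are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("register-contract").getOrCreate()

# Case 1: plain Python function; returnType is optional and defaults to string.
spark.udf.register("strlen", lambda s: len(s), IntegerType())

# Case 2: an existing UDF; its own return type is reused, so also passing
# returnType trips the TypeError check in register().
slen = udf(lambda s: len(s), IntegerType())
spark.udf.register("slen", slen)  # fine
try:
    spark.udf.register("slen2", slen, IntegerType())
except TypeError as e:
    print(e)  # "Invalid returnType: data type can not be specified when f is ..."
```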