[SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement
## What changes were proposed in this pull request?

This PR proposes to allow grouped aggregate (vectorized) pandas UDFs to be registered for use in SQL statements, for instance:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    return v.sum()

spark.udf.register("sum_udf", sum_udf)
q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
spark.sql(q).show()
```

```
+---+-----------+
| v2|sum_udf(v1)|
+---+-----------+
|  1|          1|
|  0|          5|
+---+-----------+
```

## How was this patch tested?

Manual test and unit tests.

Closes #22620 from HyukjinKwon/SPARK-25601.

Authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
parent 075dd620e3
commit 79dd4c9648
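For comparison, the same grouped aggregate UDF also works through the DataFrame API, not only via a registered SQL function. A minimal sketch, assuming a running SparkSession named `spark` with pandas and pyarrow installed (the DataFrame `df` is constructed here for illustration):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    # v is a pandas Series holding one group's values
    return v.sum()

# Same data as the SQL example above, expressed as a DataFrame
df = spark.createDataFrame([(3, 0), (2, 0), (1, 1)], ["v1", "v2"])
df.groupby("v2").agg(sum_udf(df.v1)).show()
# +---+-----------+
# | v2|sum_udf(v1)|
# +---+-----------+
# |  1|          1|
# |  0|          5|
# +---+-----------+
# (row order may vary, since GROUP BY output order is not deterministic)
```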
```diff
@@ -5642,8 +5642,9 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
-                                                     'SQL_SCALAR_PANDAS_UDF'):
+            with self.assertRaisesRegexp(
+                    ValueError,
+                    'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
                 self.spark.catalog.registerFunction("foo_udf", foo_udf)

     def test_decorator(self):
```
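The widened regex reflects the new error message: GROUPED_MAP UDFs are still rejected at registration time, and the message now also names SQL_GROUPED_AGG_PANDAS_UDF. A hypothetical repro of that error path, assuming a running SparkSession `spark` with pandas and pyarrow available:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# GROUPED_MAP UDFs remain unregistrable; only the error message changed.
foo_udf = pandas_udf(lambda pdf: pdf, "id long", PandasUDFType.GROUPED_MAP)
try:
    spark.catalog.registerFunction("foo_udf", foo_udf)
except ValueError as e:
    # Message lists SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF and
    # SQL_GROUPED_AGG_PANDAS_UDF as the accepted eval types.
    print(e)
```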
```diff
@@ -6459,6 +6460,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
                 'mixture.*aggregate function.*group aggregate pandas UDF'):
             df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

+    def test_register_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf
+        from pyspark.rdd import PythonEvalType
+
+        sum_pandas_udf = pandas_udf(
+            lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+        self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
+        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+        q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
+        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+        expected = [1, 5]
+        self.assertEqual(actual, expected)
+
+
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
```
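Outside the unittest harness, the new test body boils down to the following sketch, assuming a running SparkSession `spark` with pandas and pyarrow installed. The collected values are sorted because the order of GROUP BY output groups is not deterministic:

```python
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf

sum_pandas_udf = pandas_udf(
    lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

# register() returns the registered UDF, and the eval type is preserved.
registered = spark.udf.register("sum_pandas_udf", sum_pandas_udf)
assert registered.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
assert sorted(r[0] for r in spark.sql(q).collect()) == [1, 5]
```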
```diff
@@ -298,6 +298,15 @@ class UDFRegistration(object):
         >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
         [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

+        >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+        ... def sum_udf(v):
+        ...     return v.sum()
+        ...
+        >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
+        >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
+        >>> spark.sql(q).collect()  # doctest: +SKIP
+        [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
+
         .. note:: Registration for a user-defined function (case 2.) was added from
             Spark 2.3.0.
         """
```
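The "case 2." the note refers to is registering an existing UDF object, as opposed to case 1, registering a plain Python function together with a return type. A brief sketch of the two paths, assuming a running SparkSession `spark` (names are illustrative):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Case 1: register a plain Python function; a returnType may be supplied.
spark.udf.register("plus_one", lambda x: x + 1, "integer")

# Case 2: register an existing UDF object; no returnType may be supplied,
# since the UDF already carries one. With this patch, grouped aggregate
# pandas UDFs are accepted on this path as well.
@pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    return v.sum()

spark.udf.register("sum_udf", sum_udf)
```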
```diff
@@ -310,9 +319,11 @@ class UDFRegistration(object):
                     "Invalid returnType: data type can not be specified when f is"
                     "a user-defined function, but got %s." % returnType)
             if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
-                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
+                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
                 raise ValueError(
-                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF")
+                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
+                    "SQL_GROUPED_AGG_PANDAS_UDF")
             register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                                evalType=f.evalType,
                                                deterministic=f.deterministic)
```
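Condensed, the widened check behaves like the sketch below. `check_registrable` is a hypothetical helper name, but the constants and the message match the patch:

```python
from pyspark.rdd import PythonEvalType

_REGISTRABLE_EVAL_TYPES = (
    PythonEvalType.SQL_BATCHED_UDF,
    PythonEvalType.SQL_SCALAR_PANDAS_UDF,
    PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,  # newly accepted by this patch
)

def check_registrable(f):
    """Raise the same ValueError register() raises for other eval types."""
    if f.evalType not in _REGISTRABLE_EVAL_TYPES:
        raise ValueError(
            "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
            "SQL_GROUPED_AGG_PANDAS_UDF")
```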