[SPARK-25601][PYTHON] Register Grouped aggregate UDF Vectorized UDFs for SQL Statement

## What changes were proposed in this pull request?

This PR proposes to register Grouped aggregate UDF Vectorized UDFs for SQL Statement, for instance:

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

pandas_udf("integer", PandasUDFType.GROUPED_AGG)
def sum_udf(v):
    return v.sum()

spark.udf.register("sum_udf", sum_udf)
q = "SELECT v2, sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
spark.sql(q).show()
```

```
+---+-----------+
| v2|sum_udf(v1)|
+---+-----------+
|  1|          1|
|  0|          5|
+---+-----------+
```

## How was this patch tested?

Manual test and unit test.

Closes #22620 from HyukjinKwon/SPARK-25601.

Authored-by: hyukjinkwon <gurwls223@apache.org>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
hyukjinkwon 2018-10-04 09:36:23 +08:00
parent 075dd620e3
commit 79dd4c9648
2 changed files with 31 additions and 4 deletions

View file

@ -5642,8 +5642,9 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
'SQL_SCALAR_PANDAS_UDF'):
with self.assertRaisesRegexp(
ValueError,
'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
self.spark.catalog.registerFunction("foo_udf", foo_udf)
def test_decorator(self):
@ -6459,6 +6460,21 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
def test_register_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf
from pyspark.rdd import PythonEvalType
sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
expected = [1, 5]
self.assertEqual(actual, expected)
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,

View file

@ -298,6 +298,15 @@ class UDFRegistration(object):
>>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
>>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def sum_udf(v):
... return v.sum()
...
>>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP
>>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
>>> spark.sql(q).collect() # doctest: +SKIP
[Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
.. note:: Registration for a user-defined function (case 2.) was added from
Spark 2.3.0.
"""
@ -310,9 +319,11 @@ class UDFRegistration(object):
"Invalid returnType: data type can not be specified when f is"
"a user-defined function, but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF")
"Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF or "
"SQL_GROUPED_AGG_PANDAS_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)