[SPARK-23261][PYSPARK] Rename Pandas UDFs

## What changes were proposed in this pull request?
Rename the public APIs and names of pandas udfs.

- `PANDAS SCALAR UDF` -> `SCALAR PANDAS UDF`
- `PANDAS GROUP MAP UDF` -> `GROUPED MAP PANDAS UDF`
- `PANDAS GROUP AGG UDF` -> `GROUPED AGG PANDAS UDF`

## How was this patch tested?
The existing tests

Author: gatorsmile <gatorsmile@gmail.com>

Closes #20428 from gatorsmile/renamePandasUDFs.
This commit is contained in:
gatorsmile 2018-01-30 21:55:55 +09:00 committed by hyukjinkwon
parent 0a9ac0248b
commit 7a2ada223e
16 changed files with 120 additions and 120 deletions

View file

@ -37,16 +37,16 @@ private[spark] object PythonEvalType {
val SQL_BATCHED_UDF = 100
val SQL_PANDAS_SCALAR_UDF = 200
val SQL_PANDAS_GROUP_MAP_UDF = 201
val SQL_PANDAS_GROUP_AGG_UDF = 202
val SQL_SCALAR_PANDAS_UDF = 200
val SQL_GROUPED_MAP_PANDAS_UDF = 201
val SQL_GROUPED_AGG_PANDAS_UDF = 202
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF"
case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
}
}

View file

@ -1684,7 +1684,7 @@ Spark will fall back to create the DataFrame without Arrow.
Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and
Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator
or to wrap the function, no additional configuration is required. Currently, there are two types of
Pandas UDF: Scalar and Group Map.
Pandas UDF: Scalar and Grouped Map.
### Scalar
@ -1702,8 +1702,8 @@ The following example shows how to create a scalar Pandas UDF that computes the
</div>
</div>
### Group Map
Group map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern.
### Grouped Map
Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern.
Split-apply-combine consists of three steps:
* Split the data into groups by using `DataFrame.groupBy`.
* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The
@ -1723,7 +1723,7 @@ The following example shows how to use `groupby().apply()` to subtract the mean
<div class="codetabs">
<div data-lang="python" markdown="1">
{% include_example group_map_pandas_udf python/sql/arrow.py %}
{% include_example grouped_map_pandas_udf python/sql/arrow.py %}
</div>
</div>

View file

@ -86,15 +86,15 @@ def scalar_pandas_udf_example(spark):
# $example off:scalar_pandas_udf$
def group_map_pandas_udf_example(spark):
# $example on:group_map_pandas_udf$
def grouped_map_pandas_udf_example(spark):
# $example on:grouped_map_pandas_udf$
from pyspark.sql.functions import pandas_udf, PandasUDFType
df = spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
("id", "v"))
@pandas_udf("id long, v double", PandasUDFType.GROUP_MAP)
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def substract_mean(pdf):
# pdf is a pandas.DataFrame
v = pdf.v
@ -110,7 +110,7 @@ def group_map_pandas_udf_example(spark):
# | 2|-1.0|
# | 2| 4.0|
# +---+----+
# $example off:group_map_pandas_udf$
# $example off:grouped_map_pandas_udf$
if __name__ == "__main__":
@ -123,7 +123,7 @@ if __name__ == "__main__":
dataframe_with_arrow_example(spark)
print("Running pandas_udf scalar example")
scalar_pandas_udf_example(spark)
print("Running pandas_udf group map example")
group_map_pandas_udf_example(spark)
print("Running pandas_udf grouped map example")
grouped_map_pandas_udf_example(spark)
spark.stop()

View file

@ -68,9 +68,9 @@ class PythonEvalType(object):
SQL_BATCHED_UDF = 100
SQL_PANDAS_SCALAR_UDF = 200
SQL_PANDAS_GROUP_MAP_UDF = 201
SQL_PANDAS_GROUP_AGG_UDF = 202
SQL_SCALAR_PANDAS_UDF = 200
SQL_GROUPED_MAP_PANDAS_UDF = 201
SQL_GROUPED_AGG_PANDAS_UDF = 202
def portable_hash(x):

View file

@ -1737,8 +1737,8 @@ def translate(srcCol, matching, replace):
def create_map(*cols):
"""Creates a new map column.
:param cols: list of column names (string) or list of :class:`Column` expressions that grouped
as key-value pairs, e.g. (key1, value1, key2, value2, ...).
:param cols: list of column names (string) or list of :class:`Column` expressions that are
grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
>>> df.select(create_map('name', 'age').alias("map")).collect()
[Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
@ -2085,11 +2085,11 @@ def map_values(col):
class PandasUDFType(object):
"""Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
"""
SCALAR = PythonEvalType.SQL_PANDAS_SCALAR_UDF
SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF
GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF
GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
@since(1.3)
@ -2193,20 +2193,20 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Therefore, this can be used, for example, to ensure the length of each returned
`pandas.Series`, and can not be used as the column length.
2. GROUP_MAP
2. GROUPED_MAP
A group map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame`
The returnType should be a :class:`StructType` describing the schema of the returned
`pandas.DataFrame`.
The length of the returned `pandas.DataFrame` can be arbitrary.
Group map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v")) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
@ -2223,9 +2223,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
3. GROUP_AGG
3. GROUPED_AGG
A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
The returned scalar can be either a python primitive type, e.g., `int` or `float`
or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
@ -2239,7 +2239,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("double", PandasUDFType.GROUP_AGG) # doctest: +SKIP
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
>>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
@ -2285,21 +2285,21 @@ def pandas_udf(f=None, returnType=None, functionType=None):
eval_type = returnType
else:
# @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
else:
return_type = returnType
if functionType is not None:
eval_type = functionType
else:
eval_type = PythonEvalType.SQL_PANDAS_SCALAR_UDF
eval_type = PythonEvalType.SQL_SCALAR_PANDAS_UDF
if return_type is None:
raise ValueError("Invalid returnType: returnType can not be None")
if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF,
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]:
if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
raise ValueError("Invalid functionType: "
"functionType must be one the values from PandasUDFType")

View file

@ -98,7 +98,7 @@ class GroupedData(object):
[Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> @pandas_udf('int', PandasUDFType.GROUP_AGG) # doctest: +SKIP
>>> @pandas_udf('int', PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def min_udf(v):
... return v.min()
>>> sorted(gdf.agg(min_udf(df.age)).collect()) # doctest: +SKIP
@ -235,14 +235,14 @@ class GroupedData(object):
into memory, so the user should be aware of the potential OOM risk if data is skewed
and certain groups are too large to fit in memory.
:param udf: a group map user-defined function returned by
:param udf: a grouped map user-defined function returned by
:func:`pyspark.sql.functions.pandas_udf`.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
... ("id", "v"))
>>> @pandas_udf("id long, v double", PandasUDFType.GROUP_MAP) # doctest: +SKIP
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
@ -262,9 +262,9 @@ class GroupedData(object):
"""
# Columns are special because hasattr always return True
if isinstance(udf, Column) or not hasattr(udf, 'func') \
or udf.evalType != PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
"GROUP_MAP.")
"GROUPED_MAP.")
df = self._df
udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())

View file

@ -3621,34 +3621,34 @@ class PandasUDFTests(ReusedSQLTestCase):
udf = pandas_udf(lambda x: x, DoubleType())
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
udf = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
udf = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]),
PandasUDFType.GROUP_MAP)
PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUP_MAP)
udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
udf = pandas_udf(lambda x: x, 'v double',
functionType=PandasUDFType.GROUP_MAP)
functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
udf = pandas_udf(lambda x: x, returnType='v double',
functionType=PandasUDFType.GROUP_MAP)
functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
def test_pandas_udf_decorator(self):
from pyspark.rdd import PythonEvalType
@ -3659,45 +3659,45 @@ class PandasUDFTests(ReusedSQLTestCase):
def foo(x):
return x
self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
@pandas_udf(returnType=DoubleType())
def foo(x):
return x
self.assertEqual(foo.returnType, DoubleType())
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
schema = StructType([StructField("v", DoubleType())])
@pandas_udf(schema, PandasUDFType.GROUP_MAP)
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def foo(x):
return x
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
@pandas_udf('v double', PandasUDFType.GROUP_MAP)
@pandas_udf('v double', PandasUDFType.GROUPED_MAP)
def foo(x):
return x
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
@pandas_udf(schema, functionType=PandasUDFType.GROUP_MAP)
@pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
def foo(x):
return x
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
@pandas_udf(returnType='v double', functionType=PandasUDFType.SCALAR)
def foo(x):
return x
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUP_MAP)
@pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
def foo(x):
return x
self.assertEqual(foo.returnType, schema)
self.assertEqual(foo.evalType, PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF)
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
def test_udf_wrong_arg(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@ -3724,15 +3724,15 @@ class PandasUDFTests(ReusedSQLTestCase):
return 1
with self.assertRaisesRegexp(TypeError, 'Invalid returnType'):
@pandas_udf(returnType=PandasUDFType.GROUP_MAP)
@pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid returnType'):
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUP_MAP)
@pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
def foo(df):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUP_MAP)
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
def foo(k, v):
return k
@ -3804,11 +3804,11 @@ class ScalarPandasUDF(ReusedSQLTestCase):
random_pandas_udf = pandas_udf(
lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
self.assertEqual(random_pandas_udf.deterministic, False)
self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
"randomPandasUDF", random_pandas_udf)
self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
[row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
self.assertEqual(row[0], 7)
@ -4206,7 +4206,7 @@ class ScalarPandasUDF(ReusedSQLTestCase):
col('id').cast('int').alias('b'))
original_add = pandas_udf(lambda x, y: x + y, IntegerType())
self.assertEqual(original_add.deterministic, True)
self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
self.assertEqual(original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
new_add = self.spark.catalog.registerFunction("add1", original_add)
res1 = df.select(new_add(col('a'), col('b')))
res2 = self.spark.sql(
@ -4237,20 +4237,20 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
StructField('v', IntegerType()),
StructField('v1', DoubleType()),
StructField('v2', LongType())]),
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
self.assertPandasEqual(expected, result)
def test_register_group_map_udf(self):
def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUP_MAP)
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, 'f must be either SQL_BATCHED_UDF or '
'SQL_PANDAS_SCALAR_UDF'):
'SQL_SCALAR_PANDAS_UDF'):
self.spark.catalog.registerFunction("foo_udf", foo_udf)
def test_decorator(self):
@ -4259,7 +4259,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
@pandas_udf(
'id long, v int, v1 double, v2 long',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
def foo(pdf):
return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
@ -4275,7 +4275,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
foo = pandas_udf(
lambda pdf: pdf,
'id long, v double',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
result = df.groupby('id').apply(foo).sort('id').toPandas()
@ -4289,7 +4289,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
@pandas_udf(
'id long, v int, norm double',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
def normalize(pdf):
v = pdf.v
@ -4308,7 +4308,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
@pandas_udf(
'id long, v int, norm double',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
def normalize(pdf):
v = pdf.v
@ -4328,7 +4328,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
foo_udf = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
'id long, v int, v1 double, v2 long',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
@ -4342,7 +4342,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
foo = pandas_udf(
lambda pdf: pdf,
'id long, v map<int, int>',
PandasUDFType.GROUP_MAP
PandasUDFType.GROUPED_MAP
)
with QuietTest(self.sc):
@ -4368,7 +4368,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUP_MAP'):
with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
df.groupby('id').apply(
pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]),
PandasUDFType.SCALAR))
@ -4379,7 +4379,7 @@ class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
[StructField("id", LongType(), True),
StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.groupby('id').apply(f).collect()
@ -4422,7 +4422,7 @@ class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
def pandas_agg_mean_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUP_AGG)
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
return avg
@ -4431,7 +4431,7 @@ class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
def pandas_agg_sum_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUP_AGG)
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def sum(v):
return v.sum()
return sum
@ -4441,7 +4441,7 @@ class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUP_AGG)
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
return np.average(v, weights=w)
return weighted_mean
@ -4505,19 +4505,19 @@ class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
with QuietTest(self.sc):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG)
@pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return [v.mean(), v.std()]
with QuietTest(self.sc):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG)
@pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return v.mean(), v.std()
with QuietTest(self.sc):
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG)
@pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
def mean_and_std_udf(v):
return {v.mean(): v.std()}

View file

@ -37,9 +37,9 @@ def _wrap_function(sc, func, returnType):
def _create_udf(f, returnType, evalType):
if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF,
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF):
if evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
import inspect
from pyspark.sql.utils import require_minimum_pyarrow_version
@ -47,16 +47,16 @@ def _create_udf(f, returnType, evalType):
require_minimum_pyarrow_version()
argspec = inspect.getargspec(f)
if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \
if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
argspec.varargs is None:
raise ValueError(
"Invalid function: 0-arg pandas_udfs are not supported. "
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1:
if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
raise ValueError(
"Invalid function: pandas_udfs with function type GROUP_MAP "
"Invalid function: pandas_udfs with function type GROUPED_MAP "
"must take a single arg that is a pandas DataFrame."
)
@ -112,14 +112,15 @@ class UserDefinedFunction(object):
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \
if self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
and not isinstance(self._returnType_placeholder, StructType):
raise ValueError("Invalid returnType: returnType must be a StructType for "
"pandas_udf with function type GROUP_MAP")
elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \
"pandas_udf with function type GROUPED_MAP")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF \
and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)):
raise NotImplementedError(
"ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG")
"ArrayType, StructType and MapType are not supported with "
"PandasUDFType.GROUPED_AGG")
return self._returnType_placeholder
@ -292,9 +293,9 @@ class UDFRegistration(object):
"Invalid returnType: data type can not be specified when f is"
"a user-defined function, but got %s." % returnType)
if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
PythonEvalType.SQL_SCALAR_PANDAS_UDF]:
raise ValueError(
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
"Invalid f: f must be either SQL_BATCHED_UDF or SQL_SCALAR_PANDAS_UDF")
register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
evalType=f.evalType,
deterministic=f.deterministic)

View file

@ -74,7 +74,7 @@ def wrap_udf(f, return_type):
return lambda *a: f(*a)
def wrap_pandas_scalar_udf(f, return_type):
def wrap_scalar_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def verify_result_length(*a):
@ -90,7 +90,7 @@ def wrap_pandas_scalar_udf(f, return_type):
return lambda *a: (verify_result_length(*a), arrow_return_type)
def wrap_pandas_group_map_udf(f, return_type):
def wrap_grouped_map_pandas_udf(f, return_type):
def wrapped(*series):
import pandas as pd
@ -110,7 +110,7 @@ def wrap_pandas_group_map_udf(f, return_type):
return wrapped
def wrap_pandas_group_agg_udf(f, return_type):
def wrap_grouped_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
def wrapped(*series):
@ -133,12 +133,12 @@ def read_single_udf(pickleSer, infile, eval_type):
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF:
return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF:
return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type)
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(row_func, return_type)
else:
@ -163,9 +163,9 @@ def read_udfs(pickleSer, infile, eval_type):
func = lambda _, it: map(mapper, it)
if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF,
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF):
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
timezone = utf8_deserializer.loads(infile)
ser = ArrowStreamPandasSerializer(timezone)
else:

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType
object PythonUDF {
private[this] val SCALAR_TYPES = Set(
PythonEvalType.SQL_BATCHED_UDF,
PythonEvalType.SQL_PANDAS_SCALAR_UDF
PythonEvalType.SQL_SCALAR_PANDAS_UDF
)
def isScalarPythonUDF(e: Expression): Boolean = {
@ -36,7 +36,7 @@ object PythonUDF {
def isGroupAggPandasUDF(e: Expression): Boolean = {
e.isInstanceOf[PythonUDF] &&
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF
}
}

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.planning
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

View file

@ -449,8 +449,8 @@ class RelationalGroupedDataset protected[sql](
* workers.
*/
private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
require(expr.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
"Must pass a group map udf")
require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
"Must pass a grouped map udf")
require(expr.dataType.isInstanceOf[StructType],
"The returnType of the udf must be a StructType")

View file

@ -136,7 +136,7 @@ case class AggregateInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
pyFuncs, bufferSize, reuseWorker,
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, argOffsets, aggInputSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(projectedRowIter, context.partitionId(), context)

View file

@ -81,7 +81,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
val columnarBatchIter = new ArrowPythonRunner(
funcs, bufferSize, reuseWorker,
PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema,
PythonEvalType.SQL_SCALAR_PANDAS_UDF, argOffsets, schema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(batchIter, context.partitionId(), context)

View file

@ -160,7 +160,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
}
val evaluation = validUdfs.partition(
_.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF
_.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
) match {
case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)

View file

@ -96,7 +96,7 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF, argOffsets, schema,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)