05988b256e
### What changes were proposed in this pull request? Adds a new cogroup Pandas UDF. This allows two grouped dataframes to be cogrouped together and apply a (pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame UDF to each cogroup. **Example usage** ``` from pyspark.sql.functions import pandas_udf, PandasUDFType df1 = spark.createDataFrame( [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ("time", "id", "v1")) df2 = spark.createDataFrame( [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2")) pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) def asof_join(l, r): return pd.merge_asof(l, r, on="time", by="id") df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() ``` +--------+---+---+---+ | time| id| v1| v2| +--------+---+---+---+ |20000101| 1|1.0| x| |20000102| 1|3.0| x| |20000101| 2|2.0| y| |20000102| 2|4.0| y| +--------+---+---+---+ ### How was this patch tested? Added unit test test_pandas_udf_cogrouped_map Closes #24981 from d80tb7/SPARK-27463-poc-arrow-stream. Authored-by: Chris Martin <chris@cmartinit.co.uk> Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
99 lines
3.9 KiB
Python
99 lines
3.9 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from pyspark import since
|
|
from pyspark.rdd import PythonEvalType
|
|
from pyspark.sql.column import Column
|
|
from pyspark.sql.dataframe import DataFrame
|
|
|
|
|
|
class CoGroupedData(object):
    """
    A logical grouping of two :class:`GroupedData`,
    created by :func:`GroupedData.cogroup`.

    .. note:: Experimental

    .. versionadded:: 3.0
    """

    def __init__(self, gd1, gd2):
        # Keep both grouped sides; the SQL context is shared, so borrow it
        # from the first side.
        self._gd1 = gd1
        self._gd2 = gd2
        self.sql_ctx = gd1.sql_ctx

    @since(3.0)
    def apply(self, udf):
        """
        Applies a function to each cogroup using a pandas udf and returns the result
        as a `DataFrame`.

        The user-defined function should take two `pandas.DataFrame` and return another
        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
        as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
        are combined as a :class:`DataFrame`.

        The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
        returnType of the pandas udf.

        .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded
            into memory, so the user should be aware of the potential OOM risk if data is skewed
            and certain groups are too large to fit in memory.

        .. note:: Experimental

        :param udf: a cogrouped map user-defined function returned by
            :func:`pyspark.sql.functions.pandas_udf`.

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df1 = spark.createDataFrame(
        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ...     ("time", "id", "v1"))
        >>> df2 = spark.createDataFrame(
        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
        ...     ("time", "id", "v2"))
        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
        ... def asof_join(l, r):
        ...     return pd.merge_asof(l, r, on="time", by="id")
        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        |20000101|  2|2.0|  y|
        |20000102|  2|4.0|  y|
        +--------+---+---+---+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        """
        # Validate up front: ``hasattr`` is unreliable for Column (it answers
        # True for everything), so Column is rejected explicitly before the
        # duck-typing check on 'func' and the eval-type check.
        is_cogrouped_udf = (
            not isinstance(udf, Column)
            and hasattr(udf, 'func')
            and udf.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
        if not is_cogrouped_udf:
            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
                             "COGROUPED_MAP.")
        # Feed every column of both sides to the udf; the JVM plan node splits
        # them back into the two pandas.DataFrame arguments per cogroup.
        left_cols = self._extract_cols(self._gd1)
        right_cols = self._extract_cols(self._gd2)
        applied = udf(*(left_cols + right_cols))
        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, applied._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @staticmethod
    def _extract_cols(gd):
        # All columns of the grouped side's underlying DataFrame, in order.
        frame = gd._df
        return [frame[name] for name in frame.columns]
|