[SPARK-27463][PYTHON][FOLLOW-UP] Miscellaneous documentation and code cleanup of cogroup pandas UDF
Follow up from https://github.com/apache/spark/pull/24981 incorporating some comments from HyukjinKwon. Specifically: - Adding `CoGroupedData` to `pyspark/sql/__init__.py __all__` so that documentation is generated. - Added pydoc, including example, for the use case whereby the user supplies a cogrouping function including a key. - Added the boilerplate for doctests to cogroup.py. Note that cogroup.py only contains the apply() function which has doctests disabled as per the other Pandas Udfs. - Restricted the newly exposed RelationalGroupedDataset constructor parameters to access only by the sql package. - Some minor formatting tweaks. This was tested by running the appropriate unit tests. I'm unsure as to how to check that my change will cause the documentation to be generated correctly, but it someone can describe how I can do this I'd be happy to check. Closes #25939 from d80tb7/SPARK-27463-fixes. Authored-by: Chris Martin <chris@cmartinit.co.uk> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
39eb79ac4b
commit
76791b89f5
|
@ -51,11 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStat
|
|||
from pyspark.sql.group import GroupedData
|
||||
from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
|
||||
from pyspark.sql.window import Window, WindowSpec
|
||||
from pyspark.sql.cogroup import CoGroupedData
|
||||
|
||||
|
||||
__all__ = [
|
||||
'SparkSession', 'SQLContext', 'UDFRegistration',
|
||||
'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
|
||||
'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
|
||||
'DataFrameReader', 'DataFrameWriter'
|
||||
'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
|
||||
]
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
|
||||
from pyspark import since
|
||||
from pyspark.rdd import PythonEvalType
|
||||
|
@ -43,9 +44,9 @@ class CoGroupedData(object):
|
|||
as a `DataFrame`.
|
||||
|
||||
The user-defined function should take two `pandas.DataFrame` and return another
|
||||
`pandas.DataFrame`. For each side of the cogroup, all columns are passed together
|
||||
as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
|
||||
are combined as a :class:`DataFrame`.
|
||||
`pandas.DataFrame`. For each side of the cogroup, all columns are passed together as a
|
||||
`pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as
|
||||
a :class:`DataFrame`.
|
||||
|
||||
The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
|
||||
returnType of the pandas udf.
|
||||
|
@ -61,15 +62,16 @@ class CoGroupedData(object):
|
|||
|
||||
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
|
||||
>>> df1 = spark.createDataFrame(
|
||||
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
|
||||
... ("time", "id", "v1"))
|
||||
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
|
||||
... ("time", "id", "v1"))
|
||||
>>> df2 = spark.createDataFrame(
|
||||
... [(20000101, 1, "x"), (20000101, 2, "y")],
|
||||
... ("time", "id", "v2"))
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
|
||||
... [(20000101, 1, "x"), (20000101, 2, "y")],
|
||||
... ("time", "id", "v2"))
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(l, r):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+--------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+--------+---+---+---+
|
||||
|
@ -79,6 +81,27 @@ class CoGroupedData(object):
|
|||
|20000102| 2|4.0| y|
|
||||
+--------+---+---+---+
|
||||
|
||||
Alternatively, the user can define a function that takes three arguments. In this case,
|
||||
the grouping key(s) will be passed as the first argument and the data will be passed as the
|
||||
second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
|
||||
types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
|
||||
`pandas.DataFrame` containing all columns from the original Spark DataFrames.
|
||||
|
||||
>>> @pandas_udf("time int, id int, v1 double, v2 string",
|
||||
... PandasUDFType.COGROUPED_MAP) # doctest: +SKIP
|
||||
... def asof_join(k, l, r):
|
||||
... if k == (1,):
|
||||
... return pd.merge_asof(l, r, on="time", by="id")
|
||||
... else:
|
||||
... return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
|
||||
>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() # doctest: +SKIP
|
||||
+--------+---+---+---+
|
||||
| time| id| v1| v2|
|
||||
+--------+---+---+---+
|
||||
|20000101| 1|1.0| x|
|
||||
|20000102| 1|3.0| x|
|
||||
+--------+---+---+---+
|
||||
|
||||
.. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
|
||||
|
||||
"""
|
||||
|
@ -96,3 +119,25 @@ class CoGroupedData(object):
|
|||
def _extract_cols(gd):
|
||||
df = gd._df
|
||||
return [df[col] for col in df.columns]
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest
|
||||
from pyspark.sql import SparkSession
|
||||
import pyspark.sql.cogroup
|
||||
globs = pyspark.sql.cogroup.__dict__.copy()
|
||||
spark = SparkSession.builder\
|
||||
.master("local[4]")\
|
||||
.appName("sql.cogroup tests")\
|
||||
.getOrCreate()
|
||||
globs['spark'] = spark
|
||||
(failure_count, test_count) = doctest.testmod(
|
||||
pyspark.sql.cogroup, globs=globs,
|
||||
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
|
||||
spark.stop()
|
||||
if failure_count:
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test()
|
||||
|
|
|
@ -32,14 +32,9 @@ if have_pyarrow:
|
|||
import pyarrow as pa
|
||||
|
||||
|
||||
"""
|
||||
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
|
||||
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
|
||||
"""
|
||||
if sys.version < '3':
|
||||
_check_column_type = False
|
||||
else:
|
||||
_check_column_type = True
|
||||
# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
|
||||
# From kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
|
||||
_check_column_type = sys.version >= '3'
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
|
|
|
@ -37,14 +37,9 @@ if have_pyarrow:
|
|||
import pyarrow as pa
|
||||
|
||||
|
||||
"""
|
||||
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
|
||||
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
|
||||
"""
|
||||
if sys.version < '3':
|
||||
_check_column_type = False
|
||||
else:
|
||||
_check_column_type = True
|
||||
# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
|
||||
# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
|
||||
_check_column_type = sys.version >= '3'
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
|
|
|
@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
|
|||
*/
|
||||
@Stable
|
||||
class RelationalGroupedDataset protected[sql](
|
||||
val df: DataFrame,
|
||||
val groupingExprs: Seq[Expression],
|
||||
private[sql] val df: DataFrame,
|
||||
private[sql] val groupingExprs: Seq[Expression],
|
||||
groupType: RelationalGroupedDataset.GroupType) {
|
||||
|
||||
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
|
||||
|
|
Loading…
Reference in a new issue