From 76791b89f58f9a0116bdb4b9dbd53e482c161722 Mon Sep 17 00:00:00 2001
From: Chris Martin
Date: Mon, 30 Sep 2019 22:25:35 +0900
Subject: [PATCH] [SPARK-27463][PYTHON][FOLLOW-UP] Miscellaneous documentation
 and code cleanup of cogroup pandas UDF

Follow-up from https://github.com/apache/spark/pull/24981, incorporating some review comments from HyukjinKwon. Specifically:

- Added `CoGroupedData` to `__all__` in `pyspark/sql/__init__.py` so that its documentation is generated.
- Added pydoc, including an example, for the use case where the user supplies a cogrouping function that takes a key.
- Added the boilerplate for doctests to cogroup.py. Note that cogroup.py only contains the apply() function, whose doctests are disabled, as is the case for the other pandas UDFs.
- Restricted the newly exposed RelationalGroupedDataset constructor parameters so that they are accessible only within the sql package.
- Some minor formatting tweaks.

This was tested by running the appropriate unit tests. I'm unsure how to check that my change causes the documentation to be generated correctly, but if someone can describe how I can do this I'd be happy to check.

Closes #25939 from d80tb7/SPARK-27463-fixes.

Authored-by: Chris Martin
Signed-off-by: HyukjinKwon
---
 python/pyspark/sql/__init__.py                     |  3 +-
 python/pyspark/sql/cogroup.py                      | 63 ++++++++++++++++---
 .../tests/test_pandas_udf_cogrouped_map.py         | 11 +---
 .../sql/tests/test_pandas_udf_grouped_map.py       | 11 +---
 .../spark/sql/RelationalGroupedDataset.scala       |  4 +-
 5 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 9c760e3527..ba4c4feec7 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -51,11 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStat
 from pyspark.sql.group import GroupedData
 from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 from pyspark.sql.window import Window, WindowSpec
+from pyspark.sql.cogroup import CoGroupedData


 __all__ = [
     'SparkSession', 'SQLContext', 'UDFRegistration',
     'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
     'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
-    'DataFrameReader', 'DataFrameWriter'
+    'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
 ]
diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py
index 9b725e4baf..ef87e703bc 100644
--- a/python/pyspark/sql/cogroup.py
+++ b/python/pyspark/sql/cogroup.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys

 from pyspark import since
 from pyspark.rdd import PythonEvalType
@@ -43,9 +44,9 @@ class CoGroupedData(object):
         as a `DataFrame`.

         The user-defined function should take two `pandas.DataFrame` and return another
-        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
-        as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
-        are combined as a :class:`DataFrame`.
+        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together as a
+        `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as
+        a :class:`DataFrame`.

         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.
@@ -61,15 +62,16 @@ class CoGroupedData(object):

         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df1 = spark.createDataFrame(
-        ...    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
-        ...    ("time", "id", "v1"))
+        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
+        ...     ("time", "id", "v1"))
         >>> df2 = spark.createDataFrame(
-        ...    [(20000101, 1, "x"), (20000101, 2, "y")],
-        ...    ("time", "id", "v2"))
-        >>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
+        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
+        ...     ("time", "id", "v2"))
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...             PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
         ... def asof_join(l, r):
         ...     return pd.merge_asof(l, r, on="time", by="id")
-        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
         +--------+---+---+---+
         |    time| id| v1| v2|
         +--------+---+---+---+
@@ -79,6 +81,27 @@ class CoGroupedData(object):
         |20000102|  2|4.0|  y|
         +--------+---+---+---+

+        Alternatively, the user can define a function that takes three arguments. In this case,
+        the grouping key(s) will be passed as the first argument and the data will be passed as the
+        second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
+        types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
+        `pandas.DataFrame` containing all columns from the original Spark DataFrames.
+
+        >>> @pandas_udf("time int, id int, v1 double, v2 string",
+        ...             PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
+        ... def asof_join(k, l, r):
+        ...     if k == (1,):
+        ...         return pd.merge_asof(l, r, on="time", by="id")
+        ...     else:
+        ...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
+        >>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
+        +--------+---+---+---+
+        |    time| id| v1| v2|
+        +--------+---+---+---+
+        |20000101|  1|1.0|  x|
+        |20000102|  1|3.0|  x|
+        +--------+---+---+---+
+
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
         """
@@ -96,3 +119,25 @@
 def _extract_cols(gd):
     df = gd._df
     return [df[col] for col in df.columns]
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.cogroup
+    globs = pyspark.sql.cogroup.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.cogroup tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.cogroup, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    spark.stop()
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
index 7f3f7fa316..bc2265fc5f 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py
@@ -32,14 +32,9 @@ if have_pyarrow:
     import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index adbe2d103a..8918d5ac0c 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -37,14 +37,9 @@ if have_pyarrow:
     import pyarrow as pa


-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'


 @unittest.skipIf(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index f6d13be0e8..4d47318707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    val df: DataFrame,
-    val groupingExprs: Seq[Expression],
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {

   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
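
For reference, here is a minimal end-to-end script exercising the key-aware
cogroup path documented above. This is a sketch, not part of the patch: it
assumes a Spark build that includes this branch (i.e., PandasUDFType.COGROUPED_MAP
is available) with pandas and pyarrow installed, and it simply replays the new
doctest example outside of doctest:

    import pandas as pd

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder \
        .master("local[4]") \
        .appName("cogroup demo") \
        .getOrCreate()

    df1 = spark.createDataFrame(
        [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ("time", "id", "v1"))
    df2 = spark.createDataFrame(
        [(20000101, 1, "x"), (20000101, 2, "y")],
        ("time", "id", "v2"))

    @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
    def asof_join(key, left, right):
        # The key arrives as a tuple of numpy scalars, one per grouping expression.
        if key == (1,):
            return pd.merge_asof(left, right, on="time", by="id")
        # For any other key, return an empty frame matching the declared schema.
        return pd.DataFrame(columns=["time", "id", "v1", "v2"])

    # Each cogroup (here, one per id) is passed to asof_join as (key, left, right).
    df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
    spark.stop()

Running this should print only the rows for id 1, matching the second doctest
table in the new pydoc.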