[SPARK-27463][PYTHON][FOLLOW-UP] Miscellaneous documentation and code cleanup of cogroup pandas UDF

Follow-up to https://github.com/apache/spark/pull/24981, incorporating some comments from HyukjinKwon.

Specifically:

- Added `CoGroupedData` to `__all__` in `pyspark/sql/__init__.py` so that its documentation is generated.
- Added pydoc, including an example, for the use case where the user supplies a cogrouping function that also receives the grouping key (see the sketch after this list).
- Added the boilerplate for doctests to cogroup.py. Note that cogroup.py only contains the apply() function, whose doctests are disabled, as with the other pandas UDFs.
- Restricted the newly exposed RelationalGroupedDataset constructor parameters so that they are accessible only from the sql package.
- Some minor formatting tweaks.
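For reference, a minimal sketch of the key-aware variant mentioned above. It mirrors the docstring example added to cogroup.py in this commit; it assumes an active SparkSession named `spark` and that pandas/pyarrow are installed.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# df1/df2 mirror the docstring example below.
df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))
df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")],
    ("time", "id", "v2"))

@pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
def asof_join(key, left, right):
    # The grouping key arrives first, as a tuple of numpy scalars (e.g. (numpy.int32(1),)).
    if key == (1,):
        return pd.merge_asof(left, right, on="time", by="id")
    return pd.DataFrame(columns=["time", "id", "v1", "v2"])

df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
```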

This was tested by running the appropriate unit tests. I'm unsure how to check that my change causes the documentation to be generated correctly, but if someone can describe how I can do this, I'd be happy to check.

Closes #25939 from d80tb7/SPARK-27463-fixes.

Authored-by: Chris Martin <chris@cmartinit.co.uk>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Commit 76791b89f5 (parent 39eb79ac4b), authored by Chris Martin on 2019-09-30 22:25:35 +09:00 and committed by HyukjinKwon.
5 changed files with 64 additions and 28 deletions.

python/pyspark/sql/__init__.py

@@ -51,11 +51,12 @@ from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions
 from pyspark.sql.group import GroupedData
 from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter
 from pyspark.sql.window import Window, WindowSpec
+from pyspark.sql.cogroup import CoGroupedData

 __all__ = [
     'SparkSession', 'SQLContext', 'UDFRegistration',
     'DataFrame', 'GroupedData', 'Column', 'Catalog', 'Row',
     'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec',
-    'DataFrameReader', 'DataFrameWriter'
+    'DataFrameReader', 'DataFrameWriter', 'CoGroupedData'
 ]
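With the class added to `__all__`, `CoGroupedData` can be imported straight from `pyspark.sql` and is picked up when the API docs are generated. A small sketch, assuming an active SparkSession named `spark`:

```python
from pyspark.sql import CoGroupedData

df1 = spark.createDataFrame([(1, 1.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, "x")], ("id", "v2"))

# groupby(...).cogroup(...) returns the newly exported class.
cogrouped = df1.groupby("id").cogroup(df2.groupby("id"))
assert isinstance(cogrouped, CoGroupedData)
```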

python/pyspark/sql/cogroup.py

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import sys

 from pyspark import since
 from pyspark.rdd import PythonEvalType
@@ -43,9 +44,9 @@ class CoGroupedData(object):
 as a `DataFrame`.

 The user-defined function should take two `pandas.DataFrame` and return another
-`pandas.DataFrame`. For each side of the cogroup, all columns are passed together
-as a `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame`
-are combined as a :class:`DataFrame`.
+`pandas.DataFrame`. For each side of the cogroup, all columns are passed together as a
+`pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as
+a :class:`DataFrame`.

 The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
 returnType of the pandas udf.
@@ -61,15 +62,16 @@ class CoGroupedData(object):
 >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
 >>> df1 = spark.createDataFrame(
-...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
-...     ("time", "id", "v1"))
+...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
+...     ("time", "id", "v1"))
 >>> df2 = spark.createDataFrame(
-...     [(20000101, 1, "x"), (20000101, 2, "y")],
-...     ("time", "id", "v2"))
->>> @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP)
+...     [(20000101, 1, "x"), (20000101, 2, "y")],
+...     ("time", "id", "v2"))
+>>> @pandas_udf("time int, id int, v1 double, v2 string",
+...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
 ... def asof_join(l, r):
 ...     return pd.merge_asof(l, r, on="time", by="id")
->>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()
+>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
 +--------+---+---+---+
 | time| id| v1| v2|
 +--------+---+---+---+
@@ -79,6 +81,27 @@ class CoGroupedData(object):
 |20000102| 2|4.0| y|
 +--------+---+---+---+

+Alternatively, the user can define a function that takes three arguments. In this case,
+the grouping key(s) will be passed as the first argument and the data will be passed as the
+second and third arguments. The grouping key(s) will be passed as a tuple of numpy data
+types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in as two
+`pandas.DataFrame` containing all columns from the original Spark DataFrames.
+
+>>> @pandas_udf("time int, id int, v1 double, v2 string",
+...     PandasUDFType.COGROUPED_MAP)  # doctest: +SKIP
+... def asof_join(k, l, r):
+...     if k == (1,):
+...         return pd.merge_asof(l, r, on="time", by="id")
+...     else:
+...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
+>>> df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show()  # doctest: +SKIP
++--------+---+---+---+
+| time| id| v1| v2|
++--------+---+---+---+
+|20000101| 1|1.0| x|
+|20000102| 1|3.0| x|
++--------+---+---+---+

 .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
 """
@@ -96,3 +119,25 @@ class CoGroupedData(object):
 def _extract_cols(gd):
     df = gd._df
     return [df[col] for col in df.columns]
+
+
+def _test():
+    import doctest
+    from pyspark.sql import SparkSession
+    import pyspark.sql.cogroup
+    globs = pyspark.sql.cogroup.__dict__.copy()
+    spark = SparkSession.builder\
+        .master("local[4]")\
+        .appName("sql.cogroup tests")\
+        .getOrCreate()
+    globs['spark'] = spark
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.cogroup, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    spark.stop()
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
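As a rough illustration of how this boilerplate is exercised: running the module as `__main__` invokes `_test()`, which runs the (mostly skipped) doctests and exits non-zero on failure. A hedged sketch, assuming a working PySpark environment; in the Spark repo this is normally driven by the `python/run-tests` script:

```python
import runpy

# Execute pyspark.sql.cogroup as if it were launched from the command line,
# so the `if __name__ == "__main__":` guard calls _test().
runpy.run_module("pyspark.sql.cogroup", run_name="__main__")
```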

python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py

@@ -32,14 +32,9 @@ if have_pyarrow:
     import pyarrow as pa

-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'

 @unittest.skipIf(
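For context, a hedged sketch of how a flag like `_check_column_type` typically feeds into the pandas frame comparison in these tests (the frames and names below are illustrative, not the actual test code):

```python
import sys

import pandas as pd
from pandas.testing import assert_frame_equal

# On Python 2, DataFrame.assign can produce a mix of str/unicode column names,
# so column-name type checking is relaxed there.
_check_column_type = sys.version >= '3'

expected = pd.DataFrame({"id": [1, 2]}).assign(v=[1.0, 2.0])
result = pd.DataFrame({"id": [1, 2], "v": [1.0, 2.0]})
assert_frame_equal(expected, result, check_column_type=_check_column_type)
```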

python/pyspark/sql/tests/test_pandas_udf_grouped_map.py

@@ -37,14 +37,9 @@ if have_pyarrow:
     import pyarrow as pa

-"""
-Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
-from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
-"""
-if sys.version < '3':
-    _check_column_type = False
-else:
-    _check_column_type = True
+# Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
+# from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
+_check_column_type = sys.version >= '3'

 @unittest.skipIf(

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

@@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType}
  */
 @Stable
 class RelationalGroupedDataset protected[sql](
-    val df: DataFrame,
-    val groupingExprs: Seq[Expression],
+    private[sql] val df: DataFrame,
+    private[sql] val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {

   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {