[SPARK-36388][SPARK-36386][PYTHON] Fix DataFrame groupby-rolling and groupby-expanding to follow pandas 1.3

This PR proposes to fix `RollingGroupBy` and `ExpandingGroupBy` to follow the latest pandas behavior.

As of pandas 1.3, `RollingGroupBy` and `ExpandingGroupBy` no longer return the grouped-by column in the result values.

Before:
```python
>>> df = pd.DataFrame({"A": [1, 1, 2, 3], "B": [0, 1, 2, 3]})
>>> df.groupby("A").rolling(2).sum()
       A    B
A
1 0  NaN  NaN
  1  2.0  1.0
2 2  NaN  NaN
3 3  NaN  NaN
```

After:
```python
>>> df = pd.DataFrame({"A": [1, 1, 2, 3], "B": [0, 1, 2, 3]})
>>> df.groupby("A").rolling(2).sum()
       B
A
1 0  NaN
  1  1.0
2 2  NaN
3 3  NaN
```

We should follow the behavior of pandas as closely as possible.
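
With this fix, pandas-on-Spark produces the same shape of output. A quick sketch of the matched behavior (spacing approximate; `sort_index()` is used because pandas-on-Spark does not guarantee row order):

```python
>>> import pyspark.pandas as ps
>>> psdf = ps.DataFrame({"A": [1, 1, 2, 3], "B": [0, 1, 2, 3]})
>>> psdf.groupby("A").rolling(2).sum().sort_index()
       B
A
1 0  NaN
  1  1.0
2 2  NaN
3 3  NaN
```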

This introduces a user-facing change: the results of `RollingGroupBy` and `ExpandingGroupBy` change as described above.

Tested with the updated unit tests.

Closes #33646 from itholic/SPARK-36388.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit b8508f4876)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
4 changed files with 199 additions and 148 deletions

python/pyspark/pandas/groupby.py

@@ -125,7 +125,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         groupkeys: List[Series],
         as_index: bool,
         dropna: bool,
-        column_labels_to_exlcude: Set[Label],
+        column_labels_to_exclude: Set[Label],
         agg_columns_selected: bool,
         agg_columns: List[Series],
     ):
@@ -133,7 +133,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         self._groupkeys = groupkeys
         self._as_index = as_index
         self._dropna = dropna
-        self._column_labels_to_exlcude = column_labels_to_exlcude
+        self._column_labels_to_exclude = column_labels_to_exclude
         self._agg_columns_selected = agg_columns_selected
         self._agg_columns = agg_columns
@@ -1175,7 +1175,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         agg_columns = [
             psdf._psser_for(label)
             for label in psdf._internal.column_labels
-            if label not in self._column_labels_to_exlcude
+            if label not in self._column_labels_to_exclude
         ]
         psdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply(
@@ -1372,7 +1372,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         agg_columns = [
             psdf._psser_for(label)
             for label in psdf._internal.column_labels
-            if label not in self._column_labels_to_exlcude
+            if label not in self._column_labels_to_exclude
         ]
         data_schema = (
@@ -1890,7 +1890,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         agg_columns = [
             psdf._psser_for(label)
             for label in psdf._internal.column_labels
-            if label not in self._column_labels_to_exlcude
+            if label not in self._column_labels_to_exclude
         ]
         psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply(
@@ -2708,17 +2708,17 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
             (
                 psdf,
                 new_by_series,
-                column_labels_to_exlcude,
+                column_labels_to_exclude,
             ) = GroupBy._resolve_grouping_from_diff_dataframes(psdf, by)
         else:
             new_by_series = GroupBy._resolve_grouping(psdf, by)
-            column_labels_to_exlcude = set()
+            column_labels_to_exclude = set()
         return DataFrameGroupBy(
             psdf,
             new_by_series,
             as_index=as_index,
             dropna=dropna,
-            column_labels_to_exlcude=column_labels_to_exlcude,
+            column_labels_to_exclude=column_labels_to_exclude,
         )
def __init__(
@@ -2727,20 +2727,20 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
         by: List[Series],
         as_index: bool,
         dropna: bool,
-        column_labels_to_exlcude: Set[Label],
+        column_labels_to_exclude: Set[Label],
         agg_columns: List[Label] = None,
     ):
         agg_columns_selected = agg_columns is not None
         if agg_columns_selected:
             for label in agg_columns:
-                if label in column_labels_to_exlcude:
+                if label in column_labels_to_exclude:
                     raise KeyError(label)
         else:
             agg_columns = [
                 label
                 for label in psdf._internal.column_labels
                 if not any(label == key._column_label and key._psdf is psdf for key in by)
-                and label not in column_labels_to_exlcude
+                and label not in column_labels_to_exclude
             ]
super().__init__(
@@ -2748,7 +2748,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
             groupkeys=by,
             as_index=as_index,
             dropna=dropna,
-            column_labels_to_exlcude=column_labels_to_exlcude,
+            column_labels_to_exclude=column_labels_to_exclude,
             agg_columns_selected=agg_columns_selected,
             agg_columns=[psdf[label] for label in agg_columns],
         )
@@ -2788,7 +2788,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
             self._groupkeys,
             as_index=self._as_index,
             dropna=self._dropna,
-            column_labels_to_exlcude=self._column_labels_to_exlcude,
+            column_labels_to_exclude=self._column_labels_to_exclude,
             agg_columns=item,
         )
@@ -2932,7 +2932,7 @@ class SeriesGroupBy(GroupBy[Series]):
             groupkeys=by,
             as_index=True,
             dropna=dropna,
-            column_labels_to_exlcude=set(),
+            column_labels_to_exclude=set(),
             agg_columns_selected=True,
             agg_columns=[psser],
         )

python/pyspark/pandas/tests/test_expanding.py

@@ -146,10 +146,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
         pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
         psdf = ps.from_pandas(pdf)
         # The behavior of GroupBy.expanding is changed from pandas 1.3.
         if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
-        else:
             self.assert_eq(
                 getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
                 getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(),
@@ -162,6 +160,19 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
                 getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(),
                 getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(),
             )
+        else:
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(),
+                getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a", axis=1).sort_index(),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(),
+                getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(),
+                getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a", axis=1).sort_index(),
+            )
         self.assert_eq(
             getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(),
@@ -181,10 +192,8 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
         pdf.columns = columns
         psdf.columns = columns
         # The behavior of GroupBy.expanding is changed from pandas 1.3.
         if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
-        else:
             self.assert_eq(
                 getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
                 getattr(pdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
@@ -194,6 +203,20 @@ class ExpandingTest(PandasOnSparkTestCase, TestUtils):
                 getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
                 getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
             )
+        else:
+            self.assert_eq(
+                getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(),
+                getattr(pdf.groupby(("a", "x")).expanding(2), f)()
+                .drop(("a", "x"), axis=1)
+                .sort_index(),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(),
+                getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)()
+                .drop([("a", "x"), ("a", "y")], axis=1)
+                .sort_index(),
+            )
def test_groupby_expanding_count(self):
# The behaviour of ExpandingGroupby.count are different between pandas>=1.0.0 and lower,
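
The version gate above flips the previous logic: for pandas >= 1.3 the results are compared directly, while for older pandas the grouping column must first be dropped from the pandas result, because pandas-on-Spark now always follows the new behavior. A minimal sketch of that comparison strategy (assuming `pdf` is the pandas DataFrame used in these tests; `expected_expanding_sum` is a hypothetical helper, not part of the patch):

```python
from distutils.version import LooseVersion

import pandas as pd

def expected_expanding_sum(pdf: pd.DataFrame) -> pd.DataFrame:
    # pandas >= 1.3 already excludes the grouping column "a" from the result;
    # older pandas keeps it, so drop it to match pandas-on-Spark's output.
    result = pdf.groupby(pdf.a).expanding(2).sum()
    if LooseVersion(pd.__version__) < LooseVersion("1.3"):
        result = result.drop("a", axis=1)
    return result
```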

python/pyspark/pandas/tests/test_rolling.py

@@ -112,10 +112,8 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
         pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]})
         psdf = ps.from_pandas(pdf)
         # The behavior of GroupBy.rolling is changed from pandas 1.3.
         if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
-        else:
             self.assert_eq(
                 getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
                 getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(),
@@ -128,6 +126,19 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
                 getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
                 getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(),
             )
+        else:
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(),
+                getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a", axis=1).sort_index(),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(),
+                getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(),
+                getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a", axis=1).sort_index(),
+            )
         self.assert_eq(
             getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(),
@@ -147,10 +158,8 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
         pdf.columns = columns
         psdf.columns = columns
         # The behavior of GroupBy.rolling is changed from pandas 1.3.
         if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
-            # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
-            pass
-        else:
             self.assert_eq(
                 getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
                 getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
@@ -160,6 +169,20 @@ class RollingTest(PandasOnSparkTestCase, TestUtils):
                 getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
                 getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
             )
+        else:
+            self.assert_eq(
+                getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(),
+                getattr(pdf.groupby(("a", "x")).rolling(2), f)()
+                .drop(("a", "x"), axis=1)
+                .sort_index(),
+            )
+            self.assert_eq(
+                getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(),
+                getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)()
+                .drop([("a", "x"), ("a", "y")], axis=1)
+                .sort_index(),
+            )
def test_groupby_rolling_count(self):
self._test_groupby_rolling_func("count")

python/pyspark/pandas/window.py

@@ -36,7 +36,7 @@ from pyspark.pandas.missing.window import (
 # For running doctests and reference resolution in PyCharm.
 from pyspark import pandas as ps  # noqa: F401
 from pyspark.pandas._typing import FrameLike
-from pyspark.pandas.groupby import GroupBy
+from pyspark.pandas.groupby import GroupBy, DataFrameGroupBy
 from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, SPARK_INDEX_NAME_FORMAT
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.utils import scol_for
@@ -706,10 +706,15 @@ class RollingGroupby(RollingLike[FrameLike]):
         if groupby._agg_columns_selected:
             agg_columns = groupby._agg_columns
         else:
+            # pandas doesn't keep the groupkey as a column from 1.3 for DataFrameGroupBy
+            column_labels_to_exclude = groupby._column_labels_to_exclude.copy()
+            if isinstance(groupby, DataFrameGroupBy):
+                for groupkey in groupby._groupkeys:  # type: ignore
+                    column_labels_to_exclude.add(groupkey._internal.column_labels[0])
             agg_columns = [
                 psdf._psser_for(label)
                 for label in psdf._internal.column_labels
-                if label not in groupby._column_labels_to_exlcude
+                if label not in column_labels_to_exclude
             ]
applied = []
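
The hunk above is the core of the fix: when no aggregation columns were explicitly selected, the group-key column labels are now excluded as well, but only for `DataFrameGroupBy` (a `SeriesGroupBy` always has its column explicitly selected, so the `_agg_columns_selected` branch is taken first). A standalone sketch of the selection rule, with hypothetical names:

```python
from typing import List, Set, Tuple

Label = Tuple[str, ...]  # simplified stand-in for pandas-on-Spark column labels

def select_agg_labels(
    all_labels: List[Label],
    excluded: Set[Label],
    groupkey_labels: List[Label],
    is_dataframe_groupby: bool,
) -> List[Label]:
    # From pandas 1.3, DataFrameGroupBy.rolling/.expanding drop the grouping
    # columns from the result, so treat them as excluded as well.
    if is_dataframe_groupby:
        excluded = set(excluded) | set(groupkey_labels)
    return [label for label in all_labels if label not in excluded]

print(select_agg_labels([("A",), ("B",)], set(), [("A",)], True))  # [('B',)]
```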
@@ -777,19 +782,19 @@ class RollingGroupby(RollingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).rolling(2).count().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A    B
+                B
         A
-        2 0   1.0  1.0
-          1   2.0  2.0
-        3 2   1.0  1.0
-          3   2.0  2.0
-          4   2.0  2.0
-        4 5   1.0  1.0
-          6   2.0  2.0
-          7   2.0  2.0
-          8   2.0  2.0
-        5 9   1.0  1.0
-          10  2.0  2.0
+        2 0   1.0
+          1   2.0
+        3 2   1.0
+          3   2.0
+          4   2.0
+        4 5   1.0
+          6   2.0
+          7   2.0
+          8   2.0
+        5 9   1.0
+          10  2.0
         """
         return super().count()
@@ -831,19 +836,19 @@ class RollingGroupby(RollingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).rolling(2).sum().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                 A     B
+                 B
         A
-        2 0    NaN   NaN
-          1    4.0   8.0
-        3 2    NaN   NaN
-          3    6.0  18.0
-          4    6.0  18.0
-        4 5    NaN   NaN
-          6    8.0  32.0
-          7    8.0  32.0
-          8    8.0  32.0
-        5 9    NaN   NaN
-          10  10.0  50.0
+        2 0    NaN
+          1    8.0
+        3 2    NaN
+          3   18.0
+          4   18.0
+        4 5    NaN
+          6   32.0
+          7   32.0
+          8   32.0
+        5 9    NaN
+          10  50.0
         """
         return super().sum()
@@ -885,19 +890,19 @@ class RollingGroupby(RollingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).rolling(2).min().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().min()
@@ -939,19 +944,19 @@ class RollingGroupby(RollingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).rolling(2).max().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().max()
@@ -993,19 +998,19 @@ class RollingGroupby(RollingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).rolling(2).mean().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().mean()
@@ -1478,19 +1483,19 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).expanding(2).count().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A    B
+                B
         A
-        2 0   NaN  NaN
-          1   2.0  2.0
-        3 2   NaN  NaN
-          3   2.0  2.0
-          4   3.0  3.0
-        4 5   NaN  NaN
-          6   2.0  2.0
-          7   3.0  3.0
-          8   4.0  4.0
-        5 9   NaN  NaN
-          10  2.0  2.0
+        2 0   NaN
+          1   2.0
+        3 2   NaN
+          3   2.0
+          4   3.0
+        4 5   NaN
+          6   2.0
+          7   3.0
+          8   4.0
+        5 9   NaN
+          10  2.0
         """
         return super().count()
@@ -1532,19 +1537,19 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).expanding(2).sum().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                 A     B
+                 B
         A
-        2 0    NaN   NaN
-          1    4.0   8.0
-        3 2    NaN   NaN
-          3    6.0  18.0
-          4    9.0  27.0
-        4 5    NaN   NaN
-          6    8.0  32.0
-          7   12.0  48.0
-          8   16.0  64.0
-        5 9    NaN   NaN
-          10  10.0  50.0
+        2 0    NaN
+          1    8.0
+        3 2    NaN
+          3   18.0
+          4   27.0
+        4 5    NaN
+          6   32.0
+          7   48.0
+          8   64.0
+        5 9    NaN
+          10  50.0
         """
         return super().sum()
@@ -1586,19 +1591,19 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).expanding(2).min().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().min()
@@ -1639,19 +1644,19 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).expanding(2).max().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().max()
@@ -1693,19 +1698,19 @@ class ExpandingGroupby(ExpandingLike[FrameLike]):
         >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2})
         >>> df.groupby(df.A).expanding(2).mean().sort_index()  # doctest: +NORMALIZE_WHITESPACE
-                A     B
+                B
         A
-        2 0   NaN   NaN
-          1   2.0   4.0
-        3 2   NaN   NaN
-          3   3.0   9.0
-          4   3.0   9.0
-        4 5   NaN   NaN
-          6   4.0  16.0
-          7   4.0  16.0
-          8   4.0  16.0
-        5 9   NaN   NaN
-          10  5.0  25.0
+        2 0    NaN
+          1    4.0
+        3 2    NaN
+          3    9.0
+          4    9.0
+        4 5    NaN
+          6   16.0
+          7   16.0
+          8   16.0
+        5 9    NaN
+          10  25.0
         """
         return super().mean()
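
Note that `SeriesGroupBy` results are unchanged by this PR: the aggregation column is explicitly selected there, so the new exclusion logic never runs. A short sketch contrasting the two cases:

```python
import pyspark.pandas as ps

psdf = ps.DataFrame({"A": [1, 1, 2, 3], "B": [0, 1, 2, 3]})

# DataFrameGroupBy: the grouping column "A" is now dropped from the result.
psdf.groupby("A").rolling(2).sum()

# SeriesGroupBy: still returns only the selected series "B", as before.
psdf.B.groupby(psdf.A).rolling(2).sum()
```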