[SPARK-35473][PYTHON] Fix disallow_untyped_defs mypy checks for pyspark.pandas.groupby

### What changes were proposed in this pull request?

Adds more type annotations in the file `python/pyspark/pandas/groupby.py` and fixes the mypy check failures.

### Why are the changes needed?

We should enable more disallow_untyped_defs mypy checks.
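
As a minimal illustration of what this check enforces (toy functions, not taken from the patch), mypy with `disallow_untyped_defs = True` rejects any `def` whose parameters or return type are left unannotated:

```python
# Toy example; these functions are not part of pyspark.
from typing import List


def head(n=5):  # flagged by mypy: "Function is missing a type annotation" [no-untyped-def]
    return list(range(n))


def head_annotated(n: int = 5) -> List[int]:  # fully annotated, so accepted
    return list(range(n))
```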

### Does this PR introduce _any_ user-facing change?

Yes.
This PR adds more type annotations to the pandas APIs on Spark module, which can affect how users' development tools (such as IDEs and type checkers) interact with these APIs.
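
For example, with the annotations added here a type checker or IDE can resolve the concrete frame type returned by a group-by reduction. A small sketch of the intended effect (made-up data; executing it starts a local Spark session):

```python
import pyspark.pandas as ps

psdf = ps.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})

# GroupBy is now generic in the frame type, so these assignments should
# type-check as DataFrame and Series respectively, instead of a Union of both.
counts: ps.DataFrame = psdf.groupby("a").count()
sums: ps.Series = psdf.b.groupby(psdf.a).sum()
```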

### How was this patch tested?

The mypy check with the new configuration and the existing tests should pass.
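
For reference, one way to reproduce the check locally (an assumption about the workflow, not a documented command; it assumes the configuration lives at `python/mypy.ini` relative to the repository root) is via mypy's Python API:

```python
# Hedged sketch: run mypy programmatically against the touched module.
from mypy import api

stdout, stderr, exit_status = api.run(
    ["--config-file", "python/mypy.ini", "python/pyspark/pandas/groupby.py"]
)
print(stdout or stderr)
print("mypy exit status:", exit_status)
```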

Closes #33032 from ueshin/issues/SPARK-35473/disallow_untyped_defs_groupby.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Authored by Takuya UESHIN on 2021-06-23 09:51:33 +09:00; committed by Hyukjin Kwon
commit 68b54b702c (parent 7a21e9c48f)
6 changed files with 198 additions and 135 deletions

python/mypy.ini

@ -162,8 +162,5 @@ disallow_untyped_defs = False
[mypy-pyspark.pandas.frame]
disallow_untyped_defs = False
[mypy-pyspark.pandas.groupby]
disallow_untyped_defs = False
[mypy-pyspark.pandas.series]
disallow_untyped_defs = False

python/pyspark/pandas/base.py

@ -1129,7 +1129,11 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
return self._shift(periods, fill_value).spark.analyzed
def _shift(
self: T_IndexOps, periods: int, fill_value: Any, *, part_cols: Sequence[str] = ()
self: T_IndexOps,
periods: int,
fill_value: Any,
*,
part_cols: Sequence[Union[str, Column]] = ()
) -> T_IndexOps:
if not isinstance(periods, int):
raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__)

python/pyspark/pandas/frame.py

@ -123,6 +123,7 @@ from pyspark.pandas.typedef import (
from pyspark.pandas.plot import PandasOnSparkPlotAccessor
if TYPE_CHECKING:
from pyspark.pandas.groupby import DataFrameGroupBy # noqa: F401 (SPARK-34943)
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
@ -11587,6 +11588,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
"""
return DataFrame(pd.DataFrame.from_dict(data, orient=orient, dtype=dtype, columns=columns))
def _build_groupby(
self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
) -> "DataFrameGroupBy":
from pyspark.pandas.groupby import DataFrameGroupBy
return DataFrameGroupBy._build(self, by, as_index=as_index, dropna=dropna)
def _to_internal_pandas(self):
"""
Return a pandas DataFrame directly from _internal to avoid overhead of copy.

python/pyspark/pandas/generic.py

@ -72,7 +72,7 @@ from pyspark.pandas.window import Rolling, Expanding
if TYPE_CHECKING:
from pyspark.pandas.frame import DataFrame # noqa: F401 (SPARK-34943)
from pyspark.pandas.indexes.base import Index # noqa: F401 (SPARK-34943)
from pyspark.pandas.groupby import DataFrameGroupBy, SeriesGroupBy # noqa: F401 (SPARK-34943)
from pyspark.pandas.groupby import GroupBy # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
@ -2114,12 +2114,12 @@ class Frame(object, metaclass=ABCMeta):
# TODO: by argument only support the grouping name and as_index only for now. Documentation
# should be updated when it's supported.
def groupby(
self,
self: T_Frame,
by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
axis: Union[int, str] = 0,
as_index: bool = True,
dropna: bool = True,
) -> Union["DataFrameGroupBy", "SeriesGroupBy"]:
) -> "GroupBy[T_Frame]":
"""
Group DataFrame or Series using a Series of columns.
@ -2199,8 +2199,6 @@ class Frame(object, metaclass=ABCMeta):
2.0 2 5
NaN 1 4
"""
from pyspark.pandas.groupby import DataFrameGroupBy, SeriesGroupBy
if isinstance(by, ps.DataFrame):
raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by).__name__))
elif isinstance(by, ps.Series):
@ -2242,14 +2240,13 @@ class Frame(object, metaclass=ABCMeta):
if axis != 0:
raise NotImplementedError('axis should be either 0 or "index" currently.')
if isinstance(self, ps.DataFrame):
return DataFrameGroupBy._build(self, new_by, as_index=as_index, dropna=dropna)
elif isinstance(self, ps.Series):
return SeriesGroupBy._build(self, new_by, as_index=as_index, dropna=dropna)
else:
raise TypeError(
"Constructor expects DataFrame or Series; however, " "got [%s]" % (self,)
)
return self._build_groupby(by=new_by, as_index=as_index, dropna=dropna)
@abstractmethod
def _build_groupby(
self: T_Frame, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
) -> "GroupBy[T_Frame]":
pass
def bool(self) -> bool:
"""

python/pyspark/pandas/groupby.py

@ -26,7 +26,21 @@ from collections import OrderedDict, namedtuple
from distutils.version import LooseVersion
from functools import partial
from itertools import product
from typing import Any, Callable, List, Set, Tuple, Union, cast
from typing import (
Any,
Callable,
Dict,
Generic,
Mapping,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
)
import pandas as pd
from pandas.api.types import is_hashable, is_list_like
@ -45,6 +59,7 @@ from pyspark.sql.types import ( # noqa: F401
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.generic import Frame
from pyspark.pandas.internal import (
InternalField,
InternalFrame,
@ -75,8 +90,10 @@ from pyspark.pandas.exceptions import DataError
# to keep it the same as pandas
NamedAgg = namedtuple("NamedAgg", ["column", "aggfunc"])
T_Frame = TypeVar("T_Frame", bound=Frame)
class GroupBy(object, metaclass=ABCMeta):
class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
"""
:ivar _psdf: The parent dataframe that is used to perform the groupby
:type _psdf: DataFrame
@ -103,20 +120,36 @@ class GroupBy(object, metaclass=ABCMeta):
self._agg_columns = agg_columns
@property
def _groupkeys_scols(self):
def _groupkeys_scols(self) -> List[Column]:
return [s.spark.column for s in self._groupkeys]
@property
def _agg_columns_scols(self):
def _agg_columns_scols(self) -> List[Column]:
return [s.spark.column for s in self._agg_columns]
@abstractmethod
def _apply_series_op(self, op, should_resolve: bool = False, numeric_only: bool = False):
def _apply_series_op(
self,
op: Callable[["SeriesGroupBy"], Series],
should_resolve: bool = False,
numeric_only: bool = False,
) -> T_Frame:
pass
@abstractmethod
def _cleanup_and_return(self, psdf: DataFrame) -> T_Frame:
pass
# TODO: Series support is not implemented yet.
# TODO: not all arguments are implemented comparing to pandas' for now.
def aggregate(self, func_or_funcs=None, *args, **kwargs) -> DataFrame:
def aggregate(
self,
func_or_funcs: Optional[
Union[str, List[str], Dict[Union[Any, Tuple], Union[str, List[str]]]]
] = None,
*args: Any,
**kwargs: Any
) -> DataFrame:
"""Aggregate using one or more operations over the specified axis.
Parameters
@ -223,7 +256,7 @@ class GroupBy(object, metaclass=ABCMeta):
relabeling = func_or_funcs is None and is_multi_agg_with_relabel(**kwargs)
if relabeling:
func_or_funcs, columns, order = normalize_keyword_aggregation(kwargs)
func_or_funcs, columns, order = normalize_keyword_aggregation(kwargs) # type: ignore
if not isinstance(func_or_funcs, (str, list)):
if not isinstance(func_or_funcs, dict) or not all(
@ -274,7 +307,11 @@ class GroupBy(object, metaclass=ABCMeta):
agg = aggregate
@staticmethod
def _spark_groupby(psdf, func, groupkeys=()):
def _spark_groupby(
psdf: DataFrame,
func: Mapping[Union[Any, Tuple], Union[str, List[str]]],
groupkeys: Sequence[Series] = (),
) -> InternalFrame:
groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
groupkey_scols = [s.spark.column.alias(name) for s, name in zip(groupkeys, groupkey_names)]
@ -329,7 +366,7 @@ class GroupBy(object, metaclass=ABCMeta):
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
)
def count(self) -> Union[DataFrame, Series]:
def count(self) -> T_Frame:
"""
Compute count of group, excluding missing values.
@ -352,7 +389,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(F.count, only_numeric=False)
# TODO: We should fix See Also when Series implementation is finished.
def first(self) -> Union[DataFrame, Series]:
def first(self) -> T_Frame:
"""
Compute first of group values.
@ -363,7 +400,7 @@ class GroupBy(object, metaclass=ABCMeta):
"""
return self._reduce_for_stat_function(F.first, only_numeric=False)
def last(self) -> Union[DataFrame, Series]:
def last(self) -> T_Frame:
"""
Compute last of group values.
@ -376,7 +413,7 @@ class GroupBy(object, metaclass=ABCMeta):
lambda col: F.last(col, ignorenulls=True), only_numeric=False
)
def max(self) -> Union[DataFrame, Series]:
def max(self) -> T_Frame:
"""
Compute max of group values.
@ -388,7 +425,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(F.max, only_numeric=False)
# TODO: examples should be updated.
def mean(self) -> Union[DataFrame, Series]:
def mean(self) -> T_Frame:
"""
Compute mean of groups, excluding missing values.
@ -419,7 +456,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(F.mean, only_numeric=True)
def min(self) -> Union[DataFrame, Series]:
def min(self) -> T_Frame:
"""
Compute min of group values.
@ -431,7 +468,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(F.min, only_numeric=False)
# TODO: sync the doc.
def std(self, ddof: int = 1) -> Union[DataFrame, Series]:
def std(self, ddof: int = 1) -> T_Frame:
"""
Compute standard deviation of groups, excluding missing values.
@ -452,7 +489,7 @@ class GroupBy(object, metaclass=ABCMeta):
F.stddev_pop if ddof == 0 else F.stddev_samp, only_numeric=True
)
def sum(self) -> Union[DataFrame, Series]:
def sum(self) -> T_Frame:
"""
Compute sum of group values
@ -464,7 +501,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(F.sum, only_numeric=True)
# TODO: sync the doc.
def var(self, ddof: int = 1) -> Union[DataFrame, Series]:
def var(self, ddof: int = 1) -> T_Frame:
"""
Compute variance of groups, excluding missing values.
@ -486,7 +523,7 @@ class GroupBy(object, metaclass=ABCMeta):
)
# TODO: skipna should be implemented.
def all(self) -> Union[DataFrame, Series]:
def all(self) -> T_Frame:
"""
Returns True if all values in the group are truthful, else False.
@ -528,7 +565,7 @@ class GroupBy(object, metaclass=ABCMeta):
)
# TODO: skipna should be implemented.
def any(self) -> Union[DataFrame, Series]:
def any(self) -> T_Frame:
"""
Returns True if any value in the group is truthful, else False.
@ -644,7 +681,7 @@ class GroupBy(object, metaclass=ABCMeta):
)
return first_series(DataFrame(internal))
def diff(self, periods=1) -> Union[DataFrame, Series]:
def diff(self, periods: int = 1) -> T_Frame:
"""
First discrete difference of element.
@ -703,7 +740,7 @@ class GroupBy(object, metaclass=ABCMeta):
lambda sg: sg._psser._diff(periods, part_cols=sg._groupkeys_scols), should_resolve=True
)
def cumcount(self, ascending=True) -> Series:
def cumcount(self, ascending: bool = True) -> Series:
"""
Number each item in each group from 0 to the length of that group - 1.
@ -763,7 +800,7 @@ class GroupBy(object, metaclass=ABCMeta):
internal = ret._internal.resolved_copy
return first_series(DataFrame(internal))
def cummax(self) -> Union[DataFrame, Series]:
def cummax(self) -> T_Frame:
"""
Cumulative max for each group.
@ -812,7 +849,7 @@ class GroupBy(object, metaclass=ABCMeta):
numeric_only=True,
)
def cummin(self) -> Union[DataFrame, Series]:
def cummin(self) -> T_Frame:
"""
Cumulative min for each group.
@ -861,7 +898,7 @@ class GroupBy(object, metaclass=ABCMeta):
numeric_only=True,
)
def cumprod(self) -> Union[DataFrame, Series]:
def cumprod(self) -> T_Frame:
"""
Cumulative product for each group.
@ -910,7 +947,7 @@ class GroupBy(object, metaclass=ABCMeta):
numeric_only=True,
)
def cumsum(self) -> Union[DataFrame, Series]:
def cumsum(self) -> T_Frame:
"""
Cumulative sum for each group.
@ -959,7 +996,7 @@ class GroupBy(object, metaclass=ABCMeta):
numeric_only=True,
)
def apply(self, func, *args, **kwargs) -> Union[DataFrame, Series]:
def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Union[DataFrame, Series]:
"""
Apply function `func` group-wise and combine the results together.
@ -1137,7 +1174,7 @@ class GroupBy(object, metaclass=ABCMeta):
else:
f = SelectionMixin._builtin_table.get(func, func)
def pandas_apply(pdf, *a, **k):
def pandas_apply(pdf: pd.DataFrame, *a: Any, **k: Any) -> Any:
return f(pdf.drop(groupkey_names, axis=1), *a, **k)
should_return_series = False
@ -1206,7 +1243,7 @@ class GroupBy(object, metaclass=ABCMeta):
]
return_schema = StructType([field.struct_field for field in data_fields])
def pandas_groupby_apply(pdf):
def pandas_groupby_apply(pdf: pd.DataFrame) -> pd.DataFrame:
if not is_series_groupby and LooseVersion(pd.__version__) < LooseVersion("0.25"):
# `groupby.apply` in pandas<0.25 runs the functions twice for the first group.
@ -1214,7 +1251,9 @@ class GroupBy(object, metaclass=ABCMeta):
should_skip_first_call = True
def wrapped_func(df, *a, **k):
def wrapped_func(
df: Union[pd.DataFrame, pd.Series], *a: Any, **k: Any
) -> Union[pd.DataFrame, pd.Series]:
nonlocal should_skip_first_call
if should_skip_first_call:
should_skip_first_call = False
@ -1266,7 +1305,7 @@ class GroupBy(object, metaclass=ABCMeta):
return DataFrame(internal)
# TODO: implement 'dropna' parameter
def filter(self, func) -> Union[DataFrame, Series]:
def filter(self, func: Callable[[T_Frame], T_Frame]) -> T_Frame:
"""
Return a copy of a DataFrame excluding elements from groups that
do not satisfy the boolean criterion specified by func.
@ -1334,16 +1373,16 @@ class GroupBy(object, metaclass=ABCMeta):
if is_series_groupby:
def pandas_filter(pdf):
def pandas_filter(pdf: pd.DataFrame) -> pd.DataFrame:
return pd.DataFrame(pdf.groupby(groupkey_names)[pdf.columns[-1]].filter(func))
else:
f = SelectionMixin._builtin_table.get(func, func)
def wrapped_func(pdf):
def wrapped_func(pdf: pd.DataFrame) -> pd.DataFrame:
return f(pdf.drop(groupkey_names, axis=1))
def pandas_filter(pdf):
def pandas_filter(pdf: pd.DataFrame) -> pd.DataFrame:
return pdf.groupby(groupkey_names).filter(wrapped_func).drop(groupkey_names, axis=1)
sdf = GroupBy._spark_group_map_apply(
@ -1356,16 +1395,18 @@ class GroupBy(object, metaclass=ABCMeta):
psdf = DataFrame(self._psdf[agg_columns]._internal.with_new_sdf(sdf))
if is_series_groupby:
return first_series(psdf)
return first_series(psdf) # type: ignore
else:
return psdf
return psdf # type: ignore
@staticmethod
def _prepare_group_map_apply(psdf, groupkeys, agg_columns):
def _prepare_group_map_apply(
psdf: DataFrame, groupkeys: List[Series], agg_columns: List[Series]
) -> Tuple[DataFrame, List[Tuple], List[str]]:
groupkey_labels = [
verify_temp_column_name(psdf, "__groupkey_{}__".format(i))
for i in range(len(groupkeys))
]
] # type: List[Tuple]
psdf = psdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns]
groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels]
return DataFrame(psdf._internal.resolved_copy), groupkey_labels, groupkey_names
@ -1415,7 +1456,7 @@ class GroupBy(object, metaclass=ABCMeta):
return rename_output
def rank(self, method="average", ascending=True) -> Union[DataFrame, Series]:
def rank(self, method: str = "average", ascending: bool = True) -> T_Frame:
"""
Provide the rank of values within each group.
@ -1483,7 +1524,7 @@ class GroupBy(object, metaclass=ABCMeta):
)
# TODO: add axis parameter
def idxmax(self, skipna=True) -> Union[DataFrame, Series]:
def idxmax(self, skipna: bool = True) -> T_Frame:
"""
Return index of first occurrence of maximum over requested axis in group.
NA/null values are excluded.
@ -1562,10 +1603,10 @@ class GroupBy(object, metaclass=ABCMeta):
for psser in self._agg_columns
],
)
return DataFrame(internal)
return self._cleanup_and_return(DataFrame(internal))
# TODO: add axis parameter
def idxmin(self, skipna=True) -> Union[DataFrame, Series]:
def idxmin(self, skipna: bool = True) -> T_Frame:
"""
Return index of first occurrence of minimum over requested axis in group.
NA/null values are excluded.
@ -1644,11 +1685,16 @@ class GroupBy(object, metaclass=ABCMeta):
for psser in self._agg_columns
],
)
return DataFrame(internal)
return self._cleanup_and_return(DataFrame(internal))
def fillna(
self, value=None, method=None, axis=None, inplace=False, limit=None
) -> Union[DataFrame, Series]:
self,
value: Optional[Any] = None,
method: Optional[str] = None,
axis: Optional[Union[int, str]] = None,
inplace: bool = False,
limit: Optional[int] = None,
) -> T_Frame:
"""Fill NA/NaN values in group.
Parameters
@ -1716,7 +1762,7 @@ class GroupBy(object, metaclass=ABCMeta):
should_resolve=(method is not None),
)
def bfill(self, limit=None) -> Union[DataFrame, Series]:
def bfill(self, limit: Optional[int] = None) -> T_Frame:
"""
Synonym for `DataFrame.fillna()` with ``method=`bfill```.
@ -1767,7 +1813,7 @@ class GroupBy(object, metaclass=ABCMeta):
backfill = bfill
def ffill(self, limit=None) -> Union[DataFrame, Series]:
def ffill(self, limit: Optional[int] = None) -> T_Frame:
"""
Synonym for `DataFrame.fillna()` with ``method=`ffill```.
@ -1818,7 +1864,7 @@ class GroupBy(object, metaclass=ABCMeta):
pad = ffill
def _limit(self, n: int, asc: bool):
def _limit(self, n: int, asc: bool) -> T_Frame:
"""
Private function for tail and head.
"""
@ -1860,9 +1906,9 @@ class GroupBy(object, metaclass=ABCMeta):
)
internal = psdf._internal.with_new_sdf(sdf)
return DataFrame(internal).drop(groupkey_labels, axis=1)
return self._cleanup_and_return(DataFrame(internal).drop(groupkey_labels, axis=1))
def head(self, n=5) -> Union[DataFrame, Series]:
def head(self, n: int = 5) -> T_Frame:
"""
Return first n rows of each group.
@ -1910,7 +1956,7 @@ class GroupBy(object, metaclass=ABCMeta):
"""
return self._limit(n, asc=True)
def tail(self, n=5) -> Union[DataFrame, Series]:
def tail(self, n: int = 5) -> T_Frame:
"""
Return last n rows of each group.
@ -1963,7 +2009,7 @@ class GroupBy(object, metaclass=ABCMeta):
"""
return self._limit(n, asc=False)
def shift(self, periods=1, fill_value=None) -> Union[DataFrame, Series]:
def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> T_Frame:
"""
Shift each group by periods observations.
@ -2025,7 +2071,7 @@ class GroupBy(object, metaclass=ABCMeta):
should_resolve=True,
)
def transform(self, func, *args, **kwargs) -> Union[DataFrame, Series]:
def transform(self, func: Callable[..., pd.Series], *args: Any, **kwargs: Any) -> T_Frame:
"""
Apply function column-by-column to the GroupBy object.
@ -2151,7 +2197,7 @@ class GroupBy(object, metaclass=ABCMeta):
self._psdf, self._groupkeys, agg_columns=self._agg_columns
)
def pandas_transform(pdf):
def pandas_transform(pdf: pd.DataFrame) -> pd.DataFrame:
return pdf.groupby(groupkey_names).transform(func, *args, **kwargs)
should_infer_schema = return_sig is None
@ -2169,7 +2215,7 @@ class GroupBy(object, metaclass=ABCMeta):
)
)
if len(pdf) <= limit:
return psdf_from_pandas
return self._cleanup_and_return(psdf_from_pandas)
sdf = GroupBy._spark_group_map_apply(
psdf,
@ -2218,9 +2264,9 @@ class GroupBy(object, metaclass=ABCMeta):
spark_frame=sdf, index_spark_columns=None, data_fields=data_fields
)
return DataFrame(internal)
return self._cleanup_and_return(DataFrame(internal))
def nunique(self, dropna=True) -> Union[DataFrame, Series]:
def nunique(self, dropna: bool = True) -> T_Frame:
"""
Return DataFrame with number of distinct observations per group for each column.
@ -2273,7 +2319,7 @@ class GroupBy(object, metaclass=ABCMeta):
return self._reduce_for_stat_function(stat_function, only_numeric=False)
def rolling(self, window, min_periods=None) -> RollingGroupby:
def rolling(self, window: int, min_periods: Optional[int] = None) -> RollingGroupby:
"""
Return an rolling grouper, providing rolling
functionality per group.
@ -2302,7 +2348,7 @@ class GroupBy(object, metaclass=ABCMeta):
cast(Union[SeriesGroupBy, DataFrameGroupBy], self), window, min_periods=min_periods
)
def expanding(self, min_periods=1) -> ExpandingGroupby:
def expanding(self, min_periods: int = 1) -> ExpandingGroupby:
"""
Return an expanding grouper, providing expanding
functionality per group.
@ -2326,7 +2372,7 @@ class GroupBy(object, metaclass=ABCMeta):
cast(Union[SeriesGroupBy, DataFrameGroupBy], self), min_periods=min_periods
)
def get_group(self, name) -> Union[DataFrame, Series]:
def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) -> T_Frame:
"""
Construct DataFrame from group with provided name.
@ -2403,9 +2449,9 @@ class GroupBy(object, metaclass=ABCMeta):
if internal.spark_frame.head() is None:
raise KeyError(name)
return DataFrame(internal)
return self._cleanup_and_return(DataFrame(internal))
def median(self, numeric_only=True, accuracy=10000) -> Union[DataFrame, Series]:
def median(self, numeric_only: bool = True, accuracy: int = 10000) -> T_Frame:
"""
Compute median of groups, excluding missing values.
@ -2472,7 +2518,9 @@ class GroupBy(object, metaclass=ABCMeta):
stat_function = lambda col: F.percentile_approx(col, 0.5, accuracy)
return self._reduce_for_stat_function(stat_function, only_numeric=numeric_only)
def _reduce_for_stat_function(self, sfun, only_numeric):
def _reduce_for_stat_function(
self, sfun: Callable[[Column], Column], only_numeric: bool
) -> T_Frame:
agg_columns = self._agg_columns
agg_columns_scols = self._agg_columns_scols
@ -2518,7 +2566,7 @@ class GroupBy(object, metaclass=ABCMeta):
data_spark_columns=[scol_for(sdf, col) for col in data_columns],
column_label_names=self._psdf._internal.column_label_names,
)
psdf = DataFrame(internal)
psdf = DataFrame(internal) # type: DataFrame
if self._dropna:
psdf = DataFrame(
@ -2537,7 +2585,7 @@ class GroupBy(object, metaclass=ABCMeta):
psdf = psdf.reset_index(level=should_drop_index, drop=True)
if len(should_drop_index) < len(self._groupkeys):
psdf = psdf.reset_index()
return psdf
return self._cleanup_and_return(psdf)
@staticmethod
def _resolve_grouping_from_diff_dataframes(
@ -2582,7 +2630,9 @@ class GroupBy(object, metaclass=ABCMeta):
)
)
def assign_columns(psdf, this_column_labels, that_column_labels):
def assign_columns(
psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
) -> Tuple[Series, Tuple]:
raise NotImplementedError(
"Duplicated labels with groupby() and "
"'compute.ops_on_diff_frames' option are not supported currently "
@ -2629,7 +2679,7 @@ class GroupBy(object, metaclass=ABCMeta):
return new_by_series
class DataFrameGroupBy(GroupBy):
class DataFrameGroupBy(GroupBy[DataFrame]):
@staticmethod
def _build(
psdf: DataFrame, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool
@ -2660,7 +2710,6 @@ class DataFrameGroupBy(GroupBy):
column_labels_to_exlcude: Set[Tuple],
agg_columns: List[Tuple] = None,
):
agg_columns_selected = agg_columns is not None
if agg_columns_selected:
for label in agg_columns:
@ -2693,7 +2742,7 @@ class DataFrameGroupBy(GroupBy):
return partial(property_or_func, self)
return self.__getitem__(item)
def __getitem__(self, item):
def __getitem__(self, item: Any) -> GroupBy:
if self._as_index and is_name_like_value(item):
return SeriesGroupBy(
self._psdf._psser_for(item if is_name_like_tuple(item) else (item,)),
@ -2723,10 +2772,15 @@ class DataFrameGroupBy(GroupBy):
agg_columns=item,
)
def _apply_series_op(self, op, should_resolve: bool = False, numeric_only: bool = False):
def _apply_series_op(
self,
op: Callable[["SeriesGroupBy"], Series],
should_resolve: bool = False,
numeric_only: bool = False,
) -> DataFrame:
applied = []
for column in self._agg_columns:
applied.append(op(column.groupby(self._groupkeys)))
applied.append(op(cast(SeriesGroupBy, column.groupby(self._groupkeys))))
if numeric_only:
applied = [col for col in applied if isinstance(col.spark.data_type, NumericType)]
if not applied:
@ -2736,6 +2790,9 @@ class DataFrameGroupBy(GroupBy):
internal = internal.resolved_copy
return DataFrame(internal)
def _cleanup_and_return(self, psdf: DataFrame) -> DataFrame:
return psdf
# TODO: Implement 'percentiles', 'include', and 'exclude' arguments.
# TODO: Add ``DataFrame.select_dtypes`` to See Also when 'include'
# and 'exclude' arguments are implemented.
@ -2826,7 +2883,7 @@ class DataFrameGroupBy(GroupBy):
return DataFrame(internal).astype("float64")
class SeriesGroupBy(GroupBy):
class SeriesGroupBy(GroupBy[Series]):
@staticmethod
def _build(
psser: Series, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool
@ -2870,7 +2927,12 @@ class SeriesGroupBy(GroupBy):
return partial(property_or_func, self)
raise AttributeError(item)
def _apply_series_op(self, op, should_resolve: bool = False, numeric_only: bool = False):
def _apply_series_op(
self,
op: Callable[["SeriesGroupBy"], Series],
should_resolve: bool = False,
numeric_only: bool = False,
) -> Series:
if numeric_only and not isinstance(self._agg_columns[0].spark.data_type, NumericType):
raise DataError("No numeric types to aggregate")
psser = op(self)
@ -2880,52 +2942,22 @@ class SeriesGroupBy(GroupBy):
else:
return psser
def _reduce_for_stat_function(self, sfun, only_numeric):
return first_series(super()._reduce_for_stat_function(sfun, only_numeric))
def _cleanup_and_return(self, pdf: pd.DataFrame) -> Series:
return first_series(pdf).rename().rename(self._psser.name)
def agg(self, *args, **kwargs) -> None:
def agg(self, *args: Any, **kwargs: Any) -> None:
return MissingPandasLikeSeriesGroupBy.agg(self, *args, **kwargs)
def aggregate(self, *args, **kwargs) -> None:
def aggregate(self, *args: Any, **kwargs: Any) -> None:
return MissingPandasLikeSeriesGroupBy.aggregate(self, *args, **kwargs)
def transform(self, func, *args, **kwargs) -> Series:
return first_series(super().transform(func, *args, **kwargs)).rename(self._psser.name)
transform.__doc__ = GroupBy.transform.__doc__
def idxmin(self, skipna=True) -> Series:
return first_series(super().idxmin(skipna))
idxmin.__doc__ = GroupBy.idxmin.__doc__
def idxmax(self, skipna=True) -> Series:
return first_series(super().idxmax(skipna))
idxmax.__doc__ = GroupBy.idxmax.__doc__
def head(self, n=5) -> Series:
return first_series(super().head(n)).rename(self._psser.name)
head.__doc__ = GroupBy.head.__doc__
def tail(self, n=5) -> Series:
return first_series(super().tail(n)).rename(self._psser.name)
tail.__doc__ = GroupBy.tail.__doc__
def size(self) -> Series:
return super().size().rename(self._psser.name)
size.__doc__ = GroupBy.size.__doc__
def get_group(self, name) -> Series:
return first_series(super().get_group(name))
get_group.__doc__ = GroupBy.get_group.__doc__
# TODO: add keep parameter
def nsmallest(self, n=5) -> Series:
def nsmallest(self, n: int = 5) -> Series:
"""
Return the first n rows ordered by columns in ascending order in group.
@ -3010,7 +3042,7 @@ class SeriesGroupBy(GroupBy):
return first_series(DataFrame(internal))
# TODO: add keep parameter
def nlargest(self, n=5) -> Series:
def nlargest(self, n: int = 5) -> Series:
"""
Return the first n rows ordered by columns in descending order in group.
@ -3095,7 +3127,9 @@ class SeriesGroupBy(GroupBy):
return first_series(DataFrame(internal))
# TODO: add bins, normalize parameter
def value_counts(self, sort=None, ascending=None, dropna=True) -> Series:
def value_counts(
self, sort: Optional[bool] = None, ascending: Optional[bool] = None, dropna: bool = True
) -> Series:
"""
Compute group sizes.
@ -3188,7 +3222,7 @@ class SeriesGroupBy(GroupBy):
return self._reduce_for_stat_function(F.collect_set, only_numeric=False)
def is_multi_agg_with_relabel(**kwargs):
def is_multi_agg_with_relabel(**kwargs: Any) -> bool:
"""
Check whether the kwargs pass to .agg look like multi-agg with relabling.
@ -3215,7 +3249,9 @@ def is_multi_agg_with_relabel(**kwargs):
return all(isinstance(v, tuple) and len(v) == 2 for v in kwargs.values())
def normalize_keyword_aggregation(kwargs):
def normalize_keyword_aggregation(
kwargs: Dict[str, Tuple[Union[Any, Tuple], str]],
) -> Tuple[Dict[Union[Any, Tuple], List[str]], List[str], List[Tuple]]:
"""
Normalize user-provided kwargs.
@ -3238,7 +3274,7 @@ def normalize_keyword_aggregation(kwargs):
Examples
--------
>>> normalize_keyword_aggregation({'output': ('input', 'sum')})
(OrderedDict([('input', ['sum'])]), ('output',), [('input', 'sum')])
(OrderedDict([('input', ['sum'])]), ['output'], [('input', 'sum')])
"""
# this is due to python version issue, not sure the impact on pandas-on-Spark
PY36 = sys.version_info >= (3, 6)
@ -3246,9 +3282,9 @@ def normalize_keyword_aggregation(kwargs):
kwargs = OrderedDict(sorted(kwargs.items()))
# TODO(Py35): When we drop python 3.5, change this to defaultdict(list)
aggspec = OrderedDict()
order = []
columns, pairs = list(zip(*kwargs.items()))
aggspec = OrderedDict() # type: Dict[Union[Any, Tuple], List[str]]
order = [] # type: List[Tuple]
columns, pairs = zip(*kwargs.items())
for column, aggfunc in pairs:
if column in aggspec:
@ -3261,10 +3297,10 @@ def normalize_keyword_aggregation(kwargs):
# flattened to ('y', 'A', 'max'), it won't do anything on normal Index.
if isinstance(order[0][0], tuple):
order = [(*levs, method) for levs, method in order]
return aggspec, columns, order
return aggspec, list(columns), order
def _test():
def _test() -> None:
import os
import doctest
import sys
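
The `_reduce_for_stat_function` signature above now spells out that the reducer is a Column-to-Column callable (e.g. `F.max`, or the `percentile_approx` lambda used by `median`). A self-contained sketch of that contract against plain Spark SQL (the helper name `grouped_agg` is made up; this is not code from the patch):

```python
from typing import Callable

from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, 4), (1, 5), (2, 6)], ["a", "b"])


def grouped_agg(sfun: Callable[[Column], Column]) -> DataFrame:
    """Aggregate column ``b`` per group of ``a`` with a Column-to-Column reducer."""
    return sdf.groupBy("a").agg(sfun(F.col("b")).alias("b"))


grouped_agg(F.max).show()
grouped_agg(lambda col: F.percentile_approx(col, 0.5, 10000)).show()
```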

python/pyspark/pandas/series.py

@ -24,7 +24,19 @@ import inspect
import sys
from collections.abc import Mapping
from functools import partial, wraps, reduce
from typing import Any, Callable, Generic, Iterable, List, Optional, Tuple, TypeVar, Union, cast
from typing import (
Any,
Callable,
Generic,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
cast,
TYPE_CHECKING,
)
import numpy as np
import pandas as pd
@ -94,6 +106,8 @@ from pyspark.pandas.typedef import (
SeriesType,
)
if TYPE_CHECKING:
from pyspark.pandas.groupby import SeriesGroupBy # noqa: F401 (SPARK-34943)
# This regular expression pattern is complied and defined here to avoid to compile the same
# pattern every time it is used in _repr_ in Series.
@ -6126,6 +6140,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
result = unpack_scalar(self._internal.spark_frame.select(scol))
return result if result is not None else np.nan
def _build_groupby(
self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
) -> "SeriesGroupBy":
from pyspark.pandas.groupby import SeriesGroupBy
return SeriesGroupBy._build(self, by, as_index=as_index, dropna=dropna)
def __getitem__(self, key):
try:
if (isinstance(key, slice) and any(type(n) == int for n in [key.start, key.stop])) or (