[SPARK-35476][PYTHON] Fix disallow_untyped_defs mypy checks for pyspark.pandas.series

### What changes were proposed in this pull request?

Adds more type annotations to the file `python/pyspark/pandas/series.py` and fixes the mypy check failures.

### Why are the changes needed?

We should enable more `disallow_untyped_defs` mypy checks.
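
Removing the per-module `disallow_untyped_defs = False` override (see the `python/mypy.ini` hunk below) lets the stricter project-wide setting apply to `pyspark.pandas.series`. As a minimal illustration of what that setting enforces (the functions below are made up, not from this PR):

```python
# Under disallow_untyped_defs = True, mypy rejects an unannotated def:
def add(a, b):  # error: Function is missing a type annotation
    return a + b

# and accepts the fully annotated equivalent:
def add_typed(a: int, b: int) -> int:
    return a + b
```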

### Does this PR introduce _any_ user-facing change?

Yes.
This PR adds more type annotations to the pandas APIs on Spark module, which can affect how users' development tools (type checkers, IDEs) interact with those APIs.

### How was this patch tested?

The mypy check with a new configuration and existing tests should pass.

Closes #33045 from ueshin/issues/SPARK-35476/disallow_untyped_defs_series.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Takuya UESHIN authored on 2021-06-24 19:32:33 +09:00, committed by Hyukjin Kwon
parent 3904c0edba
commit cfcfbca965
8 changed files with 329 additions and 221 deletions

python/mypy.ini

@@ -161,6 +161,3 @@ disallow_untyped_defs = False
 [mypy-pyspark.pandas.frame]
 disallow_untyped_defs = False
-[mypy-pyspark.pandas.series]
-disallow_untyped_defs = False

python/pyspark/pandas/base.py

@@ -321,7 +321,8 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         pass

     # arithmetic operators
-    __neg__ = column_op(Column.__neg__)
+    def __neg__(self: T_IndexOps) -> T_IndexOps:
+        return cast(T_IndexOps, column_op(Column.__neg__)(self))

     def __add__(self, other: Any) -> IndexOpsLike:
         return self._dtype_op.add(self, other)
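
The hunk above replaces a `column_op(...)` class attribute with an explicitly annotated method: `column_op` yields a callable whose result mypy types as `Any`, so the method declares a self-type `TypeVar` and uses `cast` to recover the concrete subclass type. A self-contained sketch of that pattern, with made-up names (`Ops`, `untyped_op`, `IntBox`) rather than pyspark code:

```python
from typing import Any, TypeVar, cast

T = TypeVar("T", bound="Ops")

def untyped_op(obj: Any) -> Any:
    # Stands in for column_op(...): mypy sees its result as Any.
    return obj

class Ops:
    def __neg__(self: T) -> T:
        # cast() narrows the Any result back to the concrete subclass,
        # so negating a subclass instance preserves its static type.
        return cast(T, untyped_op(self))

class IntBox(Ops):
    pass

box: IntBox = -IntBox()  # mypy accepts this: the result type is IntBox
```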
@@ -397,7 +398,8 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
     def __rpow__(self, other: Any) -> IndexOpsLike:
         return self._dtype_op.rpow(self, other)

-    __abs__ = column_op(F.abs)
+    def __abs__(self: T_IndexOps) -> T_IndexOps:
+        return cast(T_IndexOps, column_op(F.abs)(self))

     # comparison operators
     def __eq__(self, other: Any) -> IndexOpsLike:  # type: ignore[override]
@@ -411,7 +413,8 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
     __ge__ = column_op(Column.__ge__)
     __gt__ = column_op(Column.__gt__)

-    __invert__ = column_op(Column.__invert__)
+    def __invert__(self: T_IndexOps) -> T_IndexOps:
+        return cast(T_IndexOps, column_op(Column.__invert__)(self))

     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators

python/pyspark/pandas/frame.py

@@ -31,16 +31,18 @@ import sys
 from itertools import zip_longest
 from typing import (
     Any,
-    Optional,
-    List,
-    Tuple,
-    Union,
+    Callable,
+    Dict,
     Generic,
-    TypeVar,
     IO,
     Iterable,
     Iterator,
-    Dict,
-    Callable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
     TYPE_CHECKING,
 )
@@ -766,7 +768,9 @@ class DataFrame(Frame, Generic[T]):
         """
         return self._pssers[label]

-    def _apply_series_op(self, op, should_resolve: bool = False):
+    def _apply_series_op(
+        self, op: Callable[["Series"], "Series"], should_resolve: bool = False
+    ) -> "DataFrame":
         applied = []
         for label in self._internal.column_labels:
             applied.append(op(self._psser_for(label)))
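
In the new `_apply_series_op` signature, `"Series"` and `"DataFrame"` are string annotations (forward references), which let the method mention classes that are defined later or imported only under `TYPE_CHECKING`. A minimal sketch with stand-in classes, not the real pyspark ones:

```python
from typing import Callable, List

class Frame:
    def __init__(self, columns: List["Col"]) -> None:
        self.columns = columns

    # "Col" and "Frame" are forward references: mypy resolves the strings
    # after the whole module has been processed, so the definition order
    # of the classes does not matter.
    def apply_col_op(self, op: Callable[["Col"], "Col"]) -> "Frame":
        return Frame([op(c) for c in self.columns])

class Col:
    def __init__(self, values: List[int]) -> None:
        self.values = values

doubled = Frame([Col([1, 2])]).apply_col_op(lambda c: Col([v * 2 for v in c.values]))
```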
@@ -1579,7 +1583,7 @@ class DataFrame(Frame, Generic[T]):
         """This is an alias of ``iteritems``."""
         return self.iteritems()

-    def to_clipboard(self, excel=True, sep=None, **kwargs) -> None:
+    def to_clipboard(self, excel: bool = True, sep: Optional[str] = None, **kwargs: Any) -> None:
         """
         Copy object to the system clipboard.
@@ -1988,25 +1992,27 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def to_latex(
         self,
-        buf=None,
-        columns=None,
-        col_space=None,
-        header=True,
-        index=True,
-        na_rep="NaN",
-        formatters=None,
-        float_format=None,
-        sparsify=None,
-        index_names=True,
-        bold_rows=False,
-        column_format=None,
-        longtable=None,
-        escape=None,
-        encoding=None,
-        decimal=".",
-        multicolumn=None,
-        multicolumn_format=None,
-        multirow=None,
+        buf: Optional[IO[str]] = None,
+        columns: Optional[List[Union[Any, Tuple]]] = None,
+        col_space: Optional[int] = None,
+        header: bool = True,
+        index: bool = True,
+        na_rep: str = "NaN",
+        formatters: Optional[
+            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+        ] = None,
+        float_format: Optional[Callable[[float], str]] = None,
+        sparsify: Optional[bool] = None,
+        index_names: bool = True,
+        bold_rows: bool = False,
+        column_format: Optional[str] = None,
+        longtable: Optional[bool] = None,
+        escape: Optional[bool] = None,
+        encoding: Optional[str] = None,
+        decimal: str = ".",
+        multicolumn: Optional[bool] = None,
+        multicolumn_format: Optional[str] = None,
+        multirow: Optional[bool] = None,
     ) -> Optional[str]:
         r"""
         Render an object to a LaTeX tabular environment table.
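
The `buf: Optional[IO[str]] = None` parameter together with the `-> Optional[str]` return type captures the usual pandas convention for `to_*` renderers: return the rendered string when no buffer is passed, otherwise write into the buffer and return `None`. A small sketch of that contract (not the real `to_latex` implementation):

```python
import io
from typing import IO, Optional

def render(buf: Optional[IO[str]] = None) -> Optional[str]:
    text = "a & b \\\\\n"
    if buf is None:
        return text  # no buffer: hand the string back to the caller
    buf.write(text)  # buffer given: write into it and return None
    return None

print(render())       # prints the fragment
out = io.StringIO()
render(out)           # same fragment, written into `out`
```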
@@ -3600,7 +3606,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return DataFrame(internal)

     def reset_index(
-        self, level=None, drop=False, inplace=False, col_level=0, col_fill=""
+        self,
+        level: Optional[Union[int, Any, Tuple, Sequence[Union[int, Any, Tuple]]]] = None,
+        drop: bool = False,
+        inplace: bool = False,
+        col_level: int = 0,
+        col_fill: str = "",
     ) -> Optional["DataFrame"]:
         """Reset the index, or a level of it.
@@ -3772,29 +3783,30 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             index_fields = []
         else:
             if is_list_like(level):
-                level = list(level)
+                level = list(cast(Sequence[Union[int, Any, Tuple]], level))
             if isinstance(level, int) or is_name_like_tuple(level):
-                level = [level]
+                level_list = [cast(Union[int, Tuple], level)]
             elif is_name_like_value(level):
-                level = [(level,)]
+                level_list = [(level,)]
             else:
-                level = [
+                level_list = [
                     lvl if isinstance(lvl, int) or is_name_like_tuple(lvl) else (lvl,)
                     for lvl in level
                 ]

-            if all(isinstance(l, int) for l in level):
-                for lev in level:
+            if all(isinstance(l, int) for l in level_list):
+                int_level_list = cast(List[int], level_list)
+                for lev in int_level_list:
                     if lev >= self._internal.index_level:
                         raise IndexError(
                             "Too many levels: Index has only {} level, not {}".format(
                                 self._internal.index_level, lev + 1
                             )
                         )
-                idx = level
-            elif all(is_name_like_tuple(lev) for lev in level):
+                idx = int_level_list
+            elif all(is_name_like_tuple(lev) for lev in level_list):
                 idx = []
-                for l in level:
+                for l in cast(List[Tuple], level_list):
                     try:
                         i = self._internal.index_names.index(l)
                         idx.append(i)
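
The `int_level_list = cast(List[int], level_list)` line exists because mypy cannot narrow a list's element type from an `all(isinstance(...))` check; the runtime check establishes the invariant and the `cast` records it for the type checker. A standalone sketch of the idiom:

```python
from typing import List, Tuple, Union, cast

def max_level(levels: List[Union[int, Tuple]]) -> int:
    if all(isinstance(lvl, int) for lvl in levels):
        # mypy still sees List[Union[int, Tuple]] here, so the cast
        # captures exactly what the all(...) check just verified.
        int_levels = cast(List[int], levels)
        return max(int_levels)
    raise TypeError("expected integer levels")

print(max_level([0, 2, 1]))  # prints 2
```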
@@ -5129,13 +5141,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             psdf._to_internal_pandas(), self.to_records, pd.DataFrame.to_records, args
         )

-    def copy(self, deep=None) -> "DataFrame":
+    def copy(self, deep: bool = True) -> "DataFrame":
         """
         Make a copy of this object's indices and data.

         Parameters
         ----------
-        deep : None
+        deep : bool, default True
             this parameter is not supported but just dummy parameter to match pandas.

         Returns
@@ -6500,7 +6512,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             self._internal.with_new_columns([self._psser_for(label) for label in column_labels])
         )

-    def droplevel(self, level, axis=0) -> "DataFrame":
+    def droplevel(
+        self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]], axis: Union[int, str] = 0
+    ) -> "DataFrame":
         """
         Return DataFrame with requested index / column level(s) removed.
@@ -6950,7 +6964,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return self._sort(by=by, ascending=ascending, inplace=inplace, na_position=na_position)

-    def swaplevel(self, i=-2, j=-1, axis=0) -> "DataFrame":
+    def swaplevel(
+        self,
+        i: Union[int, Any, Tuple] = -2,
+        j: Union[int, Any, Tuple] = -1,
+        axis: Union[int, str] = 0,
+    ) -> "DataFrame":
         """
         Swap levels i and j in a MultiIndex on a particular axis.
@@ -7108,7 +7127,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return self.copy() if i == j else self.transpose()

-    def _swaplevel_columns(self, i, j) -> InternalFrame:
+    def _swaplevel_columns(
+        self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple]
+    ) -> InternalFrame:
         assert isinstance(self.columns, pd.MultiIndex)

         for index in (i, j):
             if not isinstance(index, int) and index not in self.columns.names:
@@ -7138,7 +7159,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return internal

-    def _swaplevel_index(self, i, j) -> InternalFrame:
+    def _swaplevel_index(
+        self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple]
+    ) -> InternalFrame:
         assert isinstance(self.index, ps.MultiIndex)

         for index in (i, j):
             if not isinstance(index, int) and index not in self.index.names:
@@ -9750,7 +9773,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             lambda psser: psser._rank(method=method, ascending=ascending), should_resolve=True
         )

-    def filter(self, items=None, like=None, regex=None, axis=None) -> "DataFrame":
+    def filter(
+        self,
+        items: Optional[Sequence[Any]] = None,
+        like: Optional[str] = None,
+        regex: Optional[str] = None,
+        axis: Optional[Union[int, str]] = None,
+    ) -> "DataFrame":
         """
         Subset rows or columns of dataframe according to labels in
         the specified index.
@@ -10710,10 +10739,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 "accuracy must be an integer; however, got [%s]" % type(accuracy).__name__
             )

-        if isinstance(q, Iterable):
-            q = list(q)
+        qq = list(q) if isinstance(q, Iterable) else q  # type: Union[float, List[float]]

-        for v in q if isinstance(q, list) else [q]:
+        for v in qq if isinstance(qq, list) else [qq]:
             if not isinstance(v, float):
                 raise TypeError(
                     "q must be a float or an array of floats; however, [%s] found." % type(v)
@@ -10721,9 +10749,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             if v < 0.0 or v > 1.0:
                 raise ValueError("percentiles should all be in the interval [0, 1].")

-        def quantile(spark_column, spark_type):
+        def quantile(spark_column: Column, spark_type: DataType) -> Column:
             if isinstance(spark_type, (BooleanType, NumericType)):
-                return F.percentile_approx(spark_column.cast(DoubleType()), q, accuracy)
+                return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
             else:
                 raise TypeError(
                     "Could not convert {} ({}) to numeric".format(
@@ -10731,7 +10759,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     )
                 )

-        if isinstance(q, list):
+        if isinstance(qq, list):
             # First calculate the percentiles from all columns and map it to each `quantiles`
             # by creating each entry as a struct. So, it becomes an array of structs as below:
             #
@@ -10759,7 +10787,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     column_labels.append(label)

            if len(percentile_cols) == 0:
-                return DataFrame(index=q)
+                return DataFrame(index=qq)

            sdf = self._internal.spark_frame.select(percentile_cols)
            # Here, after select percentile cols, a spark_frame looks like below:
@@ -10772,13 +10800,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             cols_dict = OrderedDict()  # type: OrderedDict
             for column in percentile_col_names:
                 cols_dict[column] = list()
-                for i in range(len(q)):
+                for i in range(len(qq)):
                     cols_dict[column].append(scol_for(sdf, column).getItem(i).alias(column))

             internal_index_column = SPARK_DEFAULT_INDEX_NAME
             cols = []
             for i, col in enumerate(zip(*cols_dict.values())):
-                cols.append(F.struct(F.lit(q[i]).alias(internal_index_column), *col))
+                cols.append(F.struct(F.lit(qq[i]).alias(internal_index_column), *col))
             sdf = sdf.select(F.array(*cols).alias("arrays"))

             # And then, explode it and manually set the index.
@@ -10801,7 +10829,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return self._reduce_for_stat_function(
                 quantile, name="quantile", numeric_only=numeric_only
-            ).rename(q)
+            ).rename(qq)

     def query(self, expr, inplace=False) -> Optional["DataFrame"]:
         """

python/pyspark/pandas/indexes/base.py

@@ -2263,8 +2263,8 @@ class Index(IndexOpsMixin):
                 "Union between Index and MultiIndex is not yet supported"
             )
         elif isinstance(other, Series):
-            other = other.to_frame()
-            other_idx = other.set_index(other.columns[0]).index
+            other_frame = other.to_frame()
+            other_idx = other_frame.set_index(other_frame.columns[0]).index
         elif isinstance(other, DataFrame):
             raise ValueError("Index data must be 1-dimensional")
         else:

python/pyspark/pandas/indexes/multi.py

@@ -149,7 +149,7 @@ class MultiIndex(Index):
     def _column_label(self) -> Optional[Tuple]:
         return None

-    def __abs__(self) -> Index:
+    def __abs__(self) -> "MultiIndex":
         raise TypeError("TypeError: cannot perform __abs__ with this index type: MultiIndex")

     def _with_new_scol(

python/pyspark/pandas/internal.py

@@ -1432,9 +1432,9 @@ class InternalFrame(object):
                 assert isinstance(pred.spark.data_type, BooleanType), pred.spark.data_type
                 condition = pred.spark.column
             else:
-                spark_type = self.spark_frame.select(pred).schema[0].dataType
-                assert isinstance(spark_type, BooleanType), spark_type
                 condition = pred
+                spark_type = self.spark_frame.select(condition).schema[0].dataType
+                assert isinstance(spark_type, BooleanType), spark_type

         return self.with_new_sdf(self.spark_frame.filter(condition).select(self.spark_columns))

python/pyspark/pandas/series.py (file diff suppressed because it is too large)

python/pyspark/pandas/utils.py

@@ -22,10 +22,11 @@ import functools
 from collections import OrderedDict
 from contextlib import contextmanager
 import os
-from typing import (
+from typing import (  # noqa: F401 (SPARK-34943)
     Any,
     Callable,
     Dict,
+    Iterable,
     Iterator,
     List,
     Optional,
@@ -381,11 +382,11 @@ def align_diff_frames(
     # 2. Apply the given function to transform the columns in a batch and keep the new columns.
     combined_column_labels = combined._internal.column_labels

-    that_columns_to_apply = []
-    this_columns_to_apply = []
-    additional_that_columns = []
-    columns_to_keep = []
-    column_labels_to_keep = []
+    that_columns_to_apply = []  # type: List[Tuple]
+    this_columns_to_apply = []  # type: List[Tuple]
+    additional_that_columns = []  # type: List[Tuple]
+    columns_to_keep = []  # type: List[Union[Series, spark.Column]]
+    column_labels_to_keep = []  # type: List[Tuple]

     for combined_label in combined_column_labels:
         for common_label in common_column_labels:
@@ -418,9 +419,9 @@ def align_diff_frames(
     if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
         psser_set, column_labels_set = zip(
             *resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
-        )
-        columns_applied = list(psser_set)
-        column_labels_applied = list(column_labels_set)
+        )  # type: Tuple[Iterable[Series], Iterable[Tuple]]
+        columns_applied = list(psser_set)  # type: List[Union[Series, spark.Column]]
+        column_labels_applied = list(column_labels_set)  # type: List[Tuple]
     else:
         columns_applied = []
         column_labels_applied = []
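
The `# type: List[Tuple]` comments in the two hunks above are PEP 484 type comments: they declare a variable's type without the `name: Type = ...` syntax. That also explains the `# noqa: F401 (SPARK-34943)` on the `typing` import, since names referenced only inside type comments can look unused to the linter. A small sketch of the comment style:

```python
from typing import List, Tuple

def split_pairs(pairs: List[Tuple[str, int]]) -> Tuple[List[str], List[int]]:
    names = []   # type: List[str]
    values = []  # type: List[int]
    # The comments above give mypy the element types of the empty lists,
    # equivalent to `names: List[str] = []` in variable-annotation syntax.
    for name, value in pairs:
        names.append(name)
        values.append(value)
    return names, values

print(split_pairs([("a", 1), ("b", 2)]))  # (['a', 'b'], [1, 2])
```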