[SPARK-35471][PYTHON] Fix disallow_untyped_defs mypy checks for pyspark.pandas.frame

### What changes were proposed in this pull request?

Adds more type annotations in the file `python/pyspark/pandas/frame.py` and fixes the mypy check failures.

### Why are the changes needed?

We should enable more disallow_untyped_defs mypy checks.

### Does this PR introduce _any_ user-facing change?

Yes.
This PR adds more type annotations to the pandas API on Spark module, which can affect how users interact with development tools such as IDE auto-completion and static type checkers.

### How was this patch tested?

The mypy check with the new configuration, together with the existing tests, should pass.

Closes #33073 from ueshin/issues/SPARK-35471/disallow_untyped_defs_frame.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Takuya UESHIN authored on 2021-06-25 14:41:58 +09:00; committed by Hyukjin Kwon.
commit 6497ac3585 (parent c6555f1845)
7 changed files with 487 additions and 324 deletions

#### python/mypy.ini

```diff
@@ -158,6 +158,3 @@ ignore_missing_imports = True
 [mypy-pyspark.pandas.data_type_ops.*]
 disallow_untyped_defs = False
-[mypy-pyspark.pandas.frame]
-disallow_untyped_defs = False
```
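For reference, deleting the `[mypy-pyspark.pandas.frame]` override means `pyspark.pandas.frame` now falls under `disallow_untyped_defs = True`, which makes mypy reject unannotated function definitions. A minimal sketch of what that setting enforces (the function names here are made up for illustration):

```python
# With disallow_untyped_defs = True, mypy rejects unannotated definitions:
def scale(x, factor=2):  # error: Function is missing a type annotation
    return x * factor

# A fully annotated definition passes the check:
def scale_typed(x: float, factor: float = 2.0) -> float:
    return x * factor
```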

#### python/pyspark/pandas/frame.py

File diff suppressed because it is too large.

#### python/pyspark/pandas/generic.py

```diff
@@ -96,7 +96,9 @@ class Frame(object, metaclass=ABCMeta):
 
     @abstractmethod
     def _apply_series_op(
-        self: T_Frame, op: Callable[["Series"], "Series"], should_resolve: bool = False
+        self: T_Frame,
+        op: Callable[["Series"], Union["Series", Column]],
+        should_resolve: bool = False,
     ) -> T_Frame:
         pass
 
```
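The widened `Callable` reflects that some per-column operations in `frame.py` now hand back a raw Spark `Column` instead of a wrapped `Series`. A hypothetical op matching the new signature (not taken from the patch):

```python
from typing import Union

import pyspark.pandas as ps
from pyspark.sql import Column

def negate_op(psser: ps.Series) -> Union[ps.Series, Column]:
    # Returning the bare Spark column is now allowed by the widened type;
    # the concrete _apply_series_op implementations wrap it back into a Series.
    return -psser.spark.column
```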
```diff
@@ -2096,11 +2098,13 @@ class Frame(object, metaclass=ABCMeta):
         3  7  40  50
         """
 
-        def abs(psser: "Series") -> "Series":
+        def abs(psser: "Series") -> Union["Series", Column]:
             if isinstance(psser.spark.data_type, BooleanType):
                 return psser
             elif isinstance(psser.spark.data_type, NumericType):
-                return psser.spark.transform(F.abs)
+                return psser._with_new_scol(
+                    F.abs(psser.spark.column), field=psser._internal.data_fields[0]
+                )
             else:
                 raise TypeError(
                     "bad operand type for abs(): {} ({})".format(
```

#### python/pyspark/pandas/groupby.py

```diff
@@ -31,6 +31,7 @@ from typing import (
     Callable,
     Dict,
     Generic,
+    Iterator,
     Mapping,
     List,
     Optional,
```
```diff
@@ -2632,7 +2633,7 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
 
         def assign_columns(
             psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
-        ) -> Tuple[Series, Tuple]:
+        ) -> Iterator[Tuple[Series, Tuple]]:
             raise NotImplementedError(
                 "Duplicated labels with groupby() and "
                 "'compute.ops_on_diff_frames' option are not supported currently "
```
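`Iterator[...]` is the right return type because `resolve_func` callbacks are generators that yield `(Series, label)` pairs; this particular one only raises. A hypothetical conforming implementation (illustrative only, using the private `_psser_for` accessor):

```python
from typing import Iterator, List, Tuple

from pyspark.pandas import DataFrame, Series

def pick_this_columns(
    psdf: DataFrame,
    this_column_labels: List[Tuple],
    that_column_labels: List[Tuple],
) -> Iterator[Tuple[Series, Tuple]]:
    # A generator function's inferred type is an iterator of what it yields,
    # which is why the corrected annotation is Iterator[...], not Tuple[...].
    for label in this_column_labels:
        yield psdf._psser_for(label), label
```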

#### python/pyspark/pandas/namespace.py

```diff
@@ -30,6 +30,7 @@ from typing import (  # noqa: F401 (SPARK-34943)
     Union,
     cast,
     no_type_check,
+    overload,
 )
 from collections import OrderedDict
 from collections.abc import Iterable
```
```diff
@@ -1130,7 +1131,7 @@ def read_excel(
             else:
                 pdf = pdf_or_pser
 
-            psdf = from_pandas(pdf)
+            psdf = cast(DataFrame, from_pandas(pdf))
             return_schema = force_decimal_precision_scale(
                 as_nullable_spark_type(psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema)
             )
```
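The new `overload` import and this `cast` go together: `from_pandas` is declared with `@overload` variants for `pd.Series`, `pd.DataFrame`, and `pd.Index` inputs, and where mypy cannot pin down a single variant the result needs explicit narrowing. A self-contained sketch (function and variable names are illustrative, not from the patch):

```python
from typing import Union, cast

import pandas as pd
import pyspark.pandas as ps

def to_psdf(pdf_or_pser: Union[pd.DataFrame, pd.Series]) -> ps.DataFrame:
    if isinstance(pdf_or_pser, pd.Series):
        pdf = pdf_or_pser.to_frame()
    else:
        pdf = pdf_or_pser
    # from_pandas is @overload-ed per input type; the cast makes the
    # DataFrame result explicit so the annotated return type checks out.
    return cast(ps.DataFrame, ps.from_pandas(pdf))
```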

#### python/pyspark/pandas/series.py

```diff
@@ -2500,7 +2500,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         """
         inplace = validate_bool_kwarg(inplace, "inplace")
         psdf = self._psdf[[self.name]]._sort(
-            by=[self.spark.column], ascending=ascending, inplace=False, na_position=na_position
+            by=[self.spark.column], ascending=ascending, na_position=na_position
        )
 
         if inplace:
```
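Dropping `inplace=False` from the internal `_sort` call is the kind of fix the stricter checks surface: once the callee is annotated, mypy flags keyword arguments the signature does not accept. A schematic example (the `_sort` signature below is illustrative, not the real one):

```python
class Example:
    def _sort(self, by: list, ascending: bool, na_position: str) -> "Example":
        return self

# Example()._sort(by=[], ascending=True, inplace=False, na_position="last")
#   -> error: Unexpected keyword argument "inplace" for "_sort" of "Example"
```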
```diff
@@ -3041,7 +3041,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
     sample.__doc__ = DataFrame.sample.__doc__
 
-    def hist(self, bins: int = 10, **kwds: Any) -> Any:
+    @no_type_check
+    def hist(self, bins=10, **kwds):
         return self.plot.hist(bins, **kwds)
 
     hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__
```
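`@no_type_check` makes mypy skip the function entirely, so this thin plot wrapper no longer needs annotations of its own under the stricter setting. A minimal demonstration:

```python
from typing import no_type_check

@no_type_check
def hist_like(bins=10, **kwds):
    # mypy treats @no_type_check functions as unchecked, so the unannotated
    # signature does not trip disallow_untyped_defs.
    return bins
```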
```diff
@@ -4963,17 +4964,19 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if not self.index.sort_values().equals(other.index.sort_values()):
             raise ValueError("matrices are not aligned")
 
-        other = other.copy()
-        column_labels = other._internal.column_labels
+        other_copy = other.copy()  # type: DataFrame
+        column_labels = other_copy._internal.column_labels
 
-        self_column_label = verify_temp_column_name(other, "__self_column__")
-        other[self_column_label] = self
-        self_psser = other._psser_for(self_column_label)
+        self_column_label = verify_temp_column_name(other_copy, "__self_column__")
+        other_copy[self_column_label] = self
+        self_psser = other_copy._psser_for(self_column_label)
 
-        product_pssers = [other._psser_for(label) * self_psser for label in column_labels]
+        product_pssers = [
+            cast(Series, other_copy._psser_for(label) * self_psser) for label in column_labels
+        ]
 
         dot_product_psser = DataFrame(
-            other._internal.with_new_columns(product_pssers, column_labels=column_labels)
+            other_copy._internal.with_new_columns(product_pssers, column_labels=column_labels)
         ).sum()
         return cast(Series, dot_product_psser).rename(self.name)
```
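Renaming the mutated copy to `other_copy` (with an explicit `# type: DataFrame` comment) keeps mypy's view of each name stable: a variable keeps the type of its first binding, so re-assigning the `other` parameter to a value of a different type would be flagged. A generic illustration of the kind of error this avoids (not the actual method):

```python
from pyspark.pandas import DataFrame, Series

def example(other: Series) -> None:
    # Re-using the parameter name for a value of another type is flagged:
    # other = other.to_frame()  # error: Incompatible types in assignment
    other_frame = other.to_frame()  # type: DataFrame  # fresh name: OK
    print(other_frame.columns)
```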
```diff
@@ -6158,9 +6161,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
     # ----------------------------------------------------------------------
 
     def _apply_series_op(
-        self, op: Callable[["Series"], "Series"], should_resolve: bool = False
+        self, op: Callable[["Series"], Union["Series", Column]], should_resolve: bool = False
     ) -> "Series":
-        psser = op(self)
+        psser_or_scol = op(self)
+        if isinstance(psser_or_scol, Series):
+            psser = psser_or_scol
+        else:
+            psser = self._with_new_scol(cast(Column, psser_or_scol))
         if should_resolve:
             internal = psser._internal.resolved_copy
             return first_series(DataFrame(internal))
```
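The concrete implementation narrows the union with `isinstance`: a returned `Series` passes through, while a raw `Column` gets wrapped via `_with_new_scol`. A self-contained sketch of the same narrowing idiom:

```python
from typing import Union, cast

def to_text(value: Union[int, str]) -> str:
    if isinstance(value, str):
        return value  # mypy narrows the Union to str in this branch
    # Only int remains here; the patch uses an explicit cast() in the
    # analogous branch to spell the narrowing out for mypy.
    return str(cast(int, value))
```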

#### python/pyspark/pandas/utils.py

```diff
@@ -307,7 +307,9 @@ def combine_frames(
 
 def align_diff_frames(
-    resolve_func: Callable[["DataFrame", List[Tuple], List[Tuple]], Tuple["Series", Tuple]],
+    resolve_func: Callable[
+        ["DataFrame", List[Tuple], List[Tuple]], Iterator[Tuple["Series", Tuple]]
+    ],
     this: "DataFrame",
     that: "DataFrame",
     fillna: bool = True,
```
```diff
@@ -419,7 +421,7 @@ def align_diff_frames(
     if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
         psser_set, column_labels_set = zip(
             *resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
-        )  # type: Tuple[Iterable[Series], Iterable[Tuple]]
+        )
         columns_applied = list(psser_set)  # type: List[Union[Series, spark.Column]]
         column_labels_applied = list(column_labels_set)  # type: List[Tuple]
     else:
```
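The removed `# type:` comment on the `zip(*...)` call site pairs with the new `Iterator` annotation on `resolve_func`: `zip(*pairs)` transposes the yielded `(Series, label)` pairs into one tuple of Series and one tuple of labels. A plain-Python analogue:

```python
# resolve_func yields (value, label) pairs; zip(*...) transposes them into
# one tuple of values and one tuple of labels.
pairs = iter([("s1", ("a",)), ("s2", ("b",))])
values, labels = zip(*pairs)
assert values == ("s1", "s2")
assert labels == (("a",), ("b",))
```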