[SPARK-35471][PYTHON] Fix disallow_untyped_defs mypy checks for pyspark.pandas.frame
### What changes were proposed in this pull request? Adds more type annotations in the file `python/pyspark/pandas/frame.py` and fixes the mypy check failures. ### Why are the changes needed? We should enable more disallow_untyped_defs mypy checks. ### Does this PR introduce _any_ user-facing change? Yes. This PR adds more type annotations in pandas APIs on Spark module, which can impact interaction with development tools for users. ### How was this patch tested? The mypy check with a new configuration and existing tests should pass. Closes #33073 from ueshin/issues/SPARK-35471/disallow_untyped_defs_frame. Authored-by: Takuya UESHIN <ueshin@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
c6555f1845
commit
6497ac3585
|
@ -158,6 +158,3 @@ ignore_missing_imports = True
|
|||
|
||||
[mypy-pyspark.pandas.data_type_ops.*]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-pyspark.pandas.frame]
|
||||
disallow_untyped_defs = False
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -96,7 +96,9 @@ class Frame(object, metaclass=ABCMeta):
|
|||
|
||||
@abstractmethod
|
||||
def _apply_series_op(
|
||||
self: T_Frame, op: Callable[["Series"], "Series"], should_resolve: bool = False
|
||||
self: T_Frame,
|
||||
op: Callable[["Series"], Union["Series", Column]],
|
||||
should_resolve: bool = False,
|
||||
) -> T_Frame:
|
||||
pass
|
||||
|
||||
|
@ -2096,11 +2098,13 @@ class Frame(object, metaclass=ABCMeta):
|
|||
3 7 40 50
|
||||
"""
|
||||
|
||||
def abs(psser: "Series") -> "Series":
|
||||
def abs(psser: "Series") -> Union["Series", Column]:
|
||||
if isinstance(psser.spark.data_type, BooleanType):
|
||||
return psser
|
||||
elif isinstance(psser.spark.data_type, NumericType):
|
||||
return psser.spark.transform(F.abs)
|
||||
return psser._with_new_scol(
|
||||
F.abs(psser.spark.column), field=psser._internal.data_fields[0]
|
||||
)
|
||||
else:
|
||||
raise TypeError(
|
||||
"bad operand type for abs(): {} ({})".format(
|
||||
|
|
|
@ -31,6 +31,7 @@ from typing import (
|
|||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
Mapping,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -2632,7 +2633,7 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta):
|
|||
|
||||
def assign_columns(
|
||||
psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
|
||||
) -> Tuple[Series, Tuple]:
|
||||
) -> Iterator[Tuple[Series, Tuple]]:
|
||||
raise NotImplementedError(
|
||||
"Duplicated labels with groupby() and "
|
||||
"'compute.ops_on_diff_frames' option are not supported currently "
|
||||
|
|
|
@ -30,6 +30,7 @@ from typing import ( # noqa: F401 (SPARK-34943)
|
|||
Union,
|
||||
cast,
|
||||
no_type_check,
|
||||
overload,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable
|
||||
|
@ -1130,7 +1131,7 @@ def read_excel(
|
|||
else:
|
||||
pdf = pdf_or_pser
|
||||
|
||||
psdf = from_pandas(pdf)
|
||||
psdf = cast(DataFrame, from_pandas(pdf))
|
||||
return_schema = force_decimal_precision_scale(
|
||||
as_nullable_spark_type(psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema)
|
||||
)
|
||||
|
|
|
@ -2500,7 +2500,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
|
|||
"""
|
||||
inplace = validate_bool_kwarg(inplace, "inplace")
|
||||
psdf = self._psdf[[self.name]]._sort(
|
||||
by=[self.spark.column], ascending=ascending, inplace=False, na_position=na_position
|
||||
by=[self.spark.column], ascending=ascending, na_position=na_position
|
||||
)
|
||||
|
||||
if inplace:
|
||||
|
@ -3041,7 +3041,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
|
|||
|
||||
sample.__doc__ = DataFrame.sample.__doc__
|
||||
|
||||
def hist(self, bins: int = 10, **kwds: Any) -> Any:
|
||||
@no_type_check
|
||||
def hist(self, bins=10, **kwds):
|
||||
return self.plot.hist(bins, **kwds)
|
||||
|
||||
hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__
|
||||
|
@ -4963,17 +4964,19 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
|
|||
if not self.index.sort_values().equals(other.index.sort_values()):
|
||||
raise ValueError("matrices are not aligned")
|
||||
|
||||
other = other.copy()
|
||||
column_labels = other._internal.column_labels
|
||||
other_copy = other.copy() # type: DataFrame
|
||||
column_labels = other_copy._internal.column_labels
|
||||
|
||||
self_column_label = verify_temp_column_name(other, "__self_column__")
|
||||
other[self_column_label] = self
|
||||
self_psser = other._psser_for(self_column_label)
|
||||
self_column_label = verify_temp_column_name(other_copy, "__self_column__")
|
||||
other_copy[self_column_label] = self
|
||||
self_psser = other_copy._psser_for(self_column_label)
|
||||
|
||||
product_pssers = [other._psser_for(label) * self_psser for label in column_labels]
|
||||
product_pssers = [
|
||||
cast(Series, other_copy._psser_for(label) * self_psser) for label in column_labels
|
||||
]
|
||||
|
||||
dot_product_psser = DataFrame(
|
||||
other._internal.with_new_columns(product_pssers, column_labels=column_labels)
|
||||
other_copy._internal.with_new_columns(product_pssers, column_labels=column_labels)
|
||||
).sum()
|
||||
|
||||
return cast(Series, dot_product_psser).rename(self.name)
|
||||
|
@ -6158,9 +6161,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
|
|||
# ----------------------------------------------------------------------
|
||||
|
||||
def _apply_series_op(
|
||||
self, op: Callable[["Series"], "Series"], should_resolve: bool = False
|
||||
self, op: Callable[["Series"], Union["Series", Column]], should_resolve: bool = False
|
||||
) -> "Series":
|
||||
psser = op(self)
|
||||
psser_or_scol = op(self)
|
||||
if isinstance(psser_or_scol, Series):
|
||||
psser = psser_or_scol
|
||||
else:
|
||||
psser = self._with_new_scol(cast(Column, psser_or_scol))
|
||||
if should_resolve:
|
||||
internal = psser._internal.resolved_copy
|
||||
return first_series(DataFrame(internal))
|
||||
|
|
|
@ -307,7 +307,9 @@ def combine_frames(
|
|||
|
||||
|
||||
def align_diff_frames(
|
||||
resolve_func: Callable[["DataFrame", List[Tuple], List[Tuple]], Tuple["Series", Tuple]],
|
||||
resolve_func: Callable[
|
||||
["DataFrame", List[Tuple], List[Tuple]], Iterator[Tuple["Series", Tuple]]
|
||||
],
|
||||
this: "DataFrame",
|
||||
that: "DataFrame",
|
||||
fillna: bool = True,
|
||||
|
@ -419,7 +421,7 @@ def align_diff_frames(
|
|||
if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
|
||||
psser_set, column_labels_set = zip(
|
||||
*resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
|
||||
) # type: Tuple[Iterable[Series], Iterable[Tuple]]
|
||||
)
|
||||
columns_applied = list(psser_set) # type: List[Union[Series, spark.Column]]
|
||||
column_labels_applied = list(column_labels_set) # type: List[Tuple]
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue