diff --git a/python/mypy.ini b/python/mypy.ini
index 7dbd0f190d..c56f0161c3 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -158,6 +158,3 @@ ignore_missing_imports = True
 [mypy-pyspark.pandas.data_type_ops.*]
 disallow_untyped_defs = False
-
-[mypy-pyspark.pandas.frame]
-disallow_untyped_defs = False
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 82e776d224..caee343f13 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -29,6 +29,7 @@ import types
 from functools import partial, reduce
 import sys
 from itertools import zip_longest
+from types import TracebackType
 from typing import (
     Any,
     Callable,
@@ -41,9 +42,11 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
     TypeVar,
     Union,
     cast,
+    no_type_check,
     TYPE_CHECKING,
 )
 import datetime
@@ -64,19 +67,19 @@ else:
 from pandas.core.accessor import CachedAccessor
 from pandas.core.dtypes.inference import is_sequence
 from pyspark import StorageLevel
-from pyspark import sql as spark
 from pyspark.sql import Column, DataFrame as SparkDataFrame, functions as F
 from pyspark.sql.functions import pandas_udf
 from pyspark.sql.types import (  # noqa: F401 (SPARK-34943)
+    ArrayType,
     BooleanType,
     DataType,
     DoubleType,
     FloatType,
     NumericType,
+    Row,
     StringType,
     StructField,
     StructType,
-    ArrayType,
 )
 from pyspark.sql.window import Window
@@ -118,6 +121,7 @@ from pyspark.pandas.typedef import (
     infer_return_type,
     spark_type_to_pandas_dtype,
     DataFrameType,
+    Dtype,
     SeriesType,
     Scalar,
     ScalarType,
@@ -125,6 +129,8 @@ from pyspark.pandas.typedef import (
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor
 
 if TYPE_CHECKING:
+    from pyspark.sql._typing import OptionalPrimitiveType  # noqa: F401 (SPARK-34943)
+
     from pyspark.pandas.groupby import DataFrameGroupBy  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
     from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
@@ -338,7 +344,7 @@ rectangle      16.0  2.348543e+108
 T = TypeVar("T")
 
 
-def _create_tuple_for_frame_type(params):
+def _create_tuple_for_frame_type(params: Any) -> object:
     """
     This is a workaround to support variadic generic in DataFrame.
 
@@ -347,7 +353,7 @@ def _create_tuple_for_frame_type(params):
     """
     from pyspark.pandas.typedef import NameTypeHolder
 
-    if isinstance(params, zip):
+    if isinstance(params, zip):  # type: ignore
         params = [slice(name, tpe) for name, tpe in params]
 
     if isinstance(params, slice):
@@ -367,7 +373,7 @@ def _create_tuple_for_frame_type(params):
         name_classes = []
         for param in params:
-            new_class = type("NameType", (NameTypeHolder,), {})
+            new_class = type("NameType", (NameTypeHolder,), {})  # type: Type[NameTypeHolder]
             new_class.name = param.start
             # When the given argument is a numpy's dtype instance.
             new_class.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
@@ -397,6 +403,7 @@ if (3, 5) <= sys.version_info < (3, 7) and __name__ != "__main__":
     # We wrap the input params by a tuple to mimic variadic generic.
     old_getitem = GenericMeta.__getitem__  # type: ignore
 
+    @no_type_check
     def new_getitem(self, params):
         if hasattr(self, "is_dataframe"):
             return old_getitem(self, _create_tuple_for_frame_type(params))
@@ -481,6 +488,7 @@ class DataFrame(Frame, Generic[T]):
     4  2  5
     4  3  9
     """
 
+    @no_type_check
     def __init__(self, data=None, index=None, columns=None, dtype=None, copy=False):
         if isinstance(data, InternalFrame):
             assert index is None
@@ -488,7 +496,7 @@ class DataFrame(Frame, Generic[T]):
             assert dtype is None
             assert not copy
             internal = data
-        elif isinstance(data, spark.DataFrame):
+        elif isinstance(data, SparkDataFrame):
             assert index is None
             assert columns is None
             assert dtype is None
@@ -515,7 +523,7 @@ class DataFrame(Frame, Generic[T]):
         object.__setattr__(self, "_internal_frame", internal)
 
     @property
-    def _pssers(self):
+    def _pssers(self) -> Dict[Tuple, "Series"]:
         """Return a dict of column label -> Series which anchors `self`."""
         from pyspark.pandas.series import Series
 
@@ -526,7 +534,7 @@ class DataFrame(Frame, Generic[T]):
                 {label: Series(data=self, index=label) for label in self._internal.column_labels},
             )
         else:
-            psseries = self._psseries
+            psseries = self._psseries  # type: ignore
             assert len(self._internal.column_labels) == len(psseries), (
                 len(self._internal.column_labels),
                 len(psseries),
@@ -543,9 +551,11 @@ class DataFrame(Frame, Generic[T]):
 
     @property
     def _internal(self) -> InternalFrame:
-        return self._internal_frame
+        return self._internal_frame  # type: ignore
 
-    def _update_internal_frame(self, internal: InternalFrame, requires_same_anchor: bool = True):
+    def _update_internal_frame(
+        self, internal: InternalFrame, requires_same_anchor: bool = True
+    ) -> None:
         """
         Update InternalFrame with the given one.
 
@@ -739,7 +749,7 @@ class DataFrame(Frame, Generic[T]):
         )
         return first_series(DataFrame(internal)).rename(pser.name)
 
-    def _psser_for(self, label):
+    def _psser_for(self, label: Tuple) -> "Series":
         """
         Create Series with a proper column label.
@@ -769,7 +779,7 @@ class DataFrame(Frame, Generic[T]):
         return self._pssers[label]
 
     def _apply_series_op(
-        self, op: Callable[["Series"], "Series"], should_resolve: bool = False
+        self, op: Callable[["Series"], Union["Series", Column]], should_resolve: bool = False
     ) -> "DataFrame":
         applied = []
         for label in self._internal.column_labels:
@@ -780,7 +790,7 @@ class DataFrame(Frame, Generic[T]):
         return DataFrame(internal)
 
     # Arithmetic Operators
-    def _map_series_op(self, op, other):
+    def _map_series_op(self, op: str, other: Any) -> "DataFrame":
        from pyspark.pandas.base import IndexOpsMixin
 
        if not isinstance(other, DataFrame) and (
@@ -797,7 +807,11 @@ class DataFrame(Frame, Generic[T]):
             if not same_anchor(self, other):
                 # Different DataFrames
-                def apply_op(psdf, this_column_labels, that_column_labels):
+                def apply_op(
+                    psdf: DataFrame,
+                    this_column_labels: List[Tuple],
+                    that_column_labels: List[Tuple],
+                ) -> Iterator[Tuple["Series", Tuple]]:
                     for this_label, that_label in zip(this_column_labels, that_column_labels):
                         yield (
                             getattr(psdf._psser_for(this_label), op)(
@@ -833,52 +847,52 @@ class DataFrame(Frame, Generic[T]):
         else:
             return self._apply_series_op(lambda psser: getattr(psser, op)(other))
 
-    def __add__(self, other) -> "DataFrame":
+    def __add__(self, other: Any) -> "DataFrame":
         return self._map_series_op("add", other)
 
-    def __radd__(self, other) -> "DataFrame":
+    def __radd__(self, other: Any) -> "DataFrame":
         return self._map_series_op("radd", other)
 
-    def __div__(self, other) -> "DataFrame":
+    def __div__(self, other: Any) -> "DataFrame":
         return self._map_series_op("div", other)
 
-    def __rdiv__(self, other) -> "DataFrame":
+    def __rdiv__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rdiv", other)
 
-    def __truediv__(self, other) -> "DataFrame":
+    def __truediv__(self, other: Any) -> "DataFrame":
         return self._map_series_op("truediv", other)
 
-    def __rtruediv__(self, other) -> "DataFrame":
+    def __rtruediv__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rtruediv", other)
 
-    def __mul__(self, other) -> "DataFrame":
+    def __mul__(self, other: Any) -> "DataFrame":
         return self._map_series_op("mul", other)
 
-    def __rmul__(self, other) -> "DataFrame":
+    def __rmul__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rmul", other)
 
-    def __sub__(self, other) -> "DataFrame":
+    def __sub__(self, other: Any) -> "DataFrame":
         return self._map_series_op("sub", other)
 
-    def __rsub__(self, other) -> "DataFrame":
+    def __rsub__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rsub", other)
 
-    def __pow__(self, other) -> "DataFrame":
+    def __pow__(self, other: Any) -> "DataFrame":
         return self._map_series_op("pow", other)
 
-    def __rpow__(self, other) -> "DataFrame":
+    def __rpow__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rpow", other)
 
-    def __mod__(self, other) -> "DataFrame":
+    def __mod__(self, other: Any) -> "DataFrame":
         return self._map_series_op("mod", other)
 
-    def __rmod__(self, other) -> "DataFrame":
+    def __rmod__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rmod", other)
 
-    def __floordiv__(self, other) -> "DataFrame":
+    def __floordiv__(self, other: Any) -> "DataFrame":
         return self._map_series_op("floordiv", other)
 
-    def __rfloordiv__(self, other) -> "DataFrame":
+    def __rfloordiv__(self, other: Any) -> "DataFrame":
         return self._map_series_op("rfloordiv", other)
 
     def __abs__(self) -> "DataFrame":
@@ -887,7 +901,7 @@ class DataFrame(Frame, Generic[T]):
 
     def __neg__(self) -> "DataFrame":
         return self._apply_series_op(lambda psser: -psser)
 
-    def add(self, other) -> "DataFrame":
+    def add(self, other: Any) -> "DataFrame":
         return self + other
 
     # create accessor for plot
@@ -902,11 +916,13 @@ class DataFrame(Frame, Generic[T]):
     # keep the name "koalas" for backward compatibility.
     koalas = CachedAccessor("koalas", PandasOnSparkFrameMethods)
 
+    @no_type_check
     def hist(self, bins=10, **kwds):
         return self.plot.hist(bins, **kwds)
 
     hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__
 
+    @no_type_check
     def kde(self, bw_method=None, ind=None, **kwds):
         return self.plot.kde(bw_method, ind, **kwds)
 
@@ -916,14 +932,14 @@ class DataFrame(Frame, Generic[T]):
         desc="Addition", op_name="+", equiv="dataframe + other", reverse="radd"
     )
 
-    def radd(self, other) -> "DataFrame":
+    def radd(self, other: Any) -> "DataFrame":
         return other + self
 
     radd.__doc__ = _flex_doc_FRAME.format(
         desc="Addition", op_name="+", equiv="other + dataframe", reverse="add"
     )
 
-    def div(self, other) -> "DataFrame":
+    def div(self, other: Any) -> "DataFrame":
         return self / other
 
     div.__doc__ = _flex_doc_FRAME.format(
@@ -932,28 +948,28 @@ class DataFrame(Frame, Generic[T]):
 
     divide = div
 
-    def rdiv(self, other) -> "DataFrame":
+    def rdiv(self, other: Any) -> "DataFrame":
         return other / self
 
     rdiv.__doc__ = _flex_doc_FRAME.format(
         desc="Floating division", op_name="/", equiv="other / dataframe", reverse="div"
     )
 
-    def truediv(self, other) -> "DataFrame":
+    def truediv(self, other: Any) -> "DataFrame":
         return self / other
 
     truediv.__doc__ = _flex_doc_FRAME.format(
         desc="Floating division", op_name="/", equiv="dataframe / other", reverse="rtruediv"
     )
 
-    def rtruediv(self, other) -> "DataFrame":
+    def rtruediv(self, other: Any) -> "DataFrame":
         return other / self
 
     rtruediv.__doc__ = _flex_doc_FRAME.format(
         desc="Floating division", op_name="/", equiv="other / dataframe", reverse="truediv"
     )
 
-    def mul(self, other) -> "DataFrame":
+    def mul(self, other: Any) -> "DataFrame":
         return self * other
 
     mul.__doc__ = _flex_doc_FRAME.format(
@@ -962,14 +978,14 @@ class DataFrame(Frame, Generic[T]):
 
     multiply = mul
 
-    def rmul(self, other) -> "DataFrame":
+    def rmul(self, other: Any) -> "DataFrame":
         return other * self
 
     rmul.__doc__ = _flex_doc_FRAME.format(
         desc="Multiplication", op_name="*", equiv="other * dataframe", reverse="mul"
     )
 
-    def sub(self, other) -> "DataFrame":
+    def sub(self, other: Any) -> "DataFrame":
         return self - other
 
     sub.__doc__ = _flex_doc_FRAME.format(
@@ -978,49 +994,49 @@ class DataFrame(Frame, Generic[T]):
 
     subtract = sub
 
-    def rsub(self, other) -> "DataFrame":
+    def rsub(self, other: Any) -> "DataFrame":
         return other - self
 
     rsub.__doc__ = _flex_doc_FRAME.format(
         desc="Subtraction", op_name="-", equiv="other - dataframe", reverse="sub"
     )
 
-    def mod(self, other) -> "DataFrame":
+    def mod(self, other: Any) -> "DataFrame":
         return self % other
 
     mod.__doc__ = _flex_doc_FRAME.format(
         desc="Modulo", op_name="%", equiv="dataframe % other", reverse="rmod"
     )
 
-    def rmod(self, other) -> "DataFrame":
+    def rmod(self, other: Any) -> "DataFrame":
         return other % self
 
     rmod.__doc__ = _flex_doc_FRAME.format(
         desc="Modulo", op_name="%", equiv="other % dataframe", reverse="mod"
     )
 
-    def pow(self, other) -> "DataFrame":
+    def pow(self, other: Any) -> "DataFrame":
         return self ** other
 
     pow.__doc__ = _flex_doc_FRAME.format(
         desc="Exponential power of series", op_name="**", equiv="dataframe ** other", reverse="rpow"
     )
 
-    def rpow(self, other) -> "DataFrame":
+    def rpow(self, other: Any) -> "DataFrame":
         return other ** self
 
     rpow.__doc__ = _flex_doc_FRAME.format(
desc="Exponential power", op_name="**", equiv="other ** dataframe", reverse="pow" ) - def floordiv(self, other) -> "DataFrame": + def floordiv(self, other: Any) -> "DataFrame": return self // other floordiv.__doc__ = _flex_doc_FRAME.format( desc="Integer division", op_name="//", equiv="dataframe // other", reverse="rfloordiv" ) - def rfloordiv(self, other) -> "DataFrame": + def rfloordiv(self, other: Any) -> "DataFrame": return other // self rfloordiv.__doc__ = _flex_doc_FRAME.format( @@ -1028,25 +1044,25 @@ class DataFrame(Frame, Generic[T]): ) # Comparison Operators - def __eq__(self, other) -> "DataFrame": # type: ignore + def __eq__(self, other: Any) -> "DataFrame": # type: ignore[override] return self._map_series_op("eq", other) - def __ne__(self, other) -> "DataFrame": # type: ignore + def __ne__(self, other: Any) -> "DataFrame": # type: ignore[override] return self._map_series_op("ne", other) - def __lt__(self, other) -> "DataFrame": + def __lt__(self, other: Any) -> "DataFrame": return self._map_series_op("lt", other) - def __le__(self, other) -> "DataFrame": + def __le__(self, other: Any) -> "DataFrame": return self._map_series_op("le", other) - def __ge__(self, other) -> "DataFrame": + def __ge__(self, other: Any) -> "DataFrame": return self._map_series_op("ge", other) - def __gt__(self, other) -> "DataFrame": + def __gt__(self, other: Any) -> "DataFrame": return self._map_series_op("gt", other) - def eq(self, other) -> "DataFrame": + def eq(self, other: Any) -> "DataFrame": """ Compare if the current value is equal to the other. @@ -1065,7 +1081,7 @@ class DataFrame(Frame, Generic[T]): equals = eq - def gt(self, other) -> "DataFrame": + def gt(self, other: Any) -> "DataFrame": """ Compare if the current value is greater than the other. @@ -1082,7 +1098,7 @@ class DataFrame(Frame, Generic[T]): """ return self > other - def ge(self, other) -> "DataFrame": + def ge(self, other: Any) -> "DataFrame": """ Compare if the current value is greater than or equal to the other. @@ -1099,7 +1115,7 @@ class DataFrame(Frame, Generic[T]): """ return self >= other - def lt(self, other) -> "DataFrame": + def lt(self, other: Any) -> "DataFrame": """ Compare if the current value is less than the other. @@ -1116,7 +1132,7 @@ class DataFrame(Frame, Generic[T]): """ return self < other - def le(self, other) -> "DataFrame": + def le(self, other: Any) -> "DataFrame": """ Compare if the current value is less than or equal to the other. @@ -1133,7 +1149,7 @@ class DataFrame(Frame, Generic[T]): """ return self <= other - def ne(self, other) -> "DataFrame": + def ne(self, other: Any) -> "DataFrame": """ Compare if the current value is not equal to the other. @@ -1150,7 +1166,7 @@ class DataFrame(Frame, Generic[T]): """ return self != other - def applymap(self, func) -> "DataFrame": + def applymap(self, func: Callable[[Any], Any]) -> "DataFrame": """ Apply a function to a Dataframe elementwise. @@ -1213,9 +1229,7 @@ class DataFrame(Frame, Generic[T]): return self._apply_series_op(lambda psser: psser.apply(func)) # TODO: not all arguments are implemented comparing to pandas' for now. - def aggregate( - self, func: Union[List[str], Dict[Any, List[str]]] - ) -> Union["Series", "DataFrame", "Index"]: + def aggregate(self, func: Union[List[str], Dict[Any, List[str]]]) -> "DataFrame": """Aggregate using one or more operations over the specified axis. 
 
         Parameters
@@ -1333,7 +1347,7 @@ class DataFrame(Frame, Generic[T]):
 
     agg = aggregate
 
-    def corr(self, method="pearson") -> Union["Series", "DataFrame", "Index"]:
+    def corr(self, method: str = "pearson") -> "DataFrame":
         """
         Compute pairwise correlation of columns, excluding NA/null values.
 
@@ -1375,9 +1389,9 @@
 
         * `min_periods` argument is not supported
         """
-        return ps.from_pandas(corr(self, method))
+        return cast(DataFrame, ps.from_pandas(corr(self, method)))
 
-    def iteritems(self) -> Iterator:
+    def iteritems(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]:
         """
         Iterator over (column name, Series) pairs.
 
@@ -1421,7 +1435,7 @@
             for label in self._internal.column_labels
         )
 
-    def iterrows(self) -> Iterator:
+    def iterrows(self) -> Iterator[Tuple[Union[Any, Tuple], pd.Series]]:
         """
         Iterate over DataFrame rows as (index, Series) pairs.
 
@@ -1467,7 +1481,7 @@
         internal_index_columns = self._internal.index_spark_column_names
         internal_data_columns = self._internal.data_spark_column_names
 
-        def extract_kv_from_spark_row(row):
+        def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]:
             k = (
                 row[internal_index_columns[0]]
                 if len(internal_index_columns) == 1
@@ -1482,7 +1496,9 @@
             s = pd.Series(v, index=columns, name=k)
             yield k, s
 
-    def itertuples(self, index: bool = True, name: Optional[str] = "PandasOnSpark") -> Iterator:
+    def itertuples(
+        self, index: bool = True, name: Optional[str] = "PandasOnSpark"
+    ) -> Iterator[Tuple]:
         """
         Iterate over DataFrame rows as namedtuples.
 
@@ -1554,7 +1570,7 @@
         index_spark_column_names = self._internal.index_spark_column_names
         data_spark_column_names = self._internal.data_spark_column_names
 
-        def extract_kv_from_spark_row(row):
+        def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]:
             k = (
                 row[index_spark_column_names[0]]
                 if len(index_spark_column_names) == 1
@@ -1579,7 +1595,7 @@
         ):
             yield tuple(([k] if index else []) + list(v))
 
-    def items(self) -> Iterator:
+    def items(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]:
         """This is an alias of ``iteritems``."""
         return self.iteritems()
 
@@ -1660,28 +1676,30 @@
     def to_html(
         self,
-        buf=None,
-        columns=None,
-        col_space=None,
-        header=True,
-        index=True,
-        na_rep="NaN",
-        formatters=None,
-        float_format=None,
-        sparsify=None,
-        index_names=True,
-        justify=None,
-        max_rows=None,
-        max_cols=None,
-        show_dimensions=False,
-        decimal=".",
-        bold_rows=True,
-        classes=None,
-        escape=True,
-        notebook=False,
-        border=None,
-        table_id=None,
-        render_links=False,
+        buf: Optional[IO[str]] = None,
+        columns: Optional[Sequence[Union[Any, Tuple]]] = None,
+        col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None,
+        header: bool = True,
+        index: bool = True,
+        na_rep: str = "NaN",
+        formatters: Optional[
+            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+        ] = None,
+        float_format: Optional[Callable[[float], str]] = None,
+        sparsify: Optional[bool] = None,
+        index_names: bool = True,
+        justify: Optional[str] = None,
+        max_rows: Optional[int] = None,
+        max_cols: Optional[int] = None,
+        show_dimensions: bool = False,
+        decimal: str = ".",
+        bold_rows: bool = True,
+        classes: Optional[Union[str, list, tuple]] = None,
+        escape: bool = True,
+        notebook: bool = False,
+        border: Optional[int] = None,
+        table_id: Optional[str] = None,
+        render_links: bool = False,
     ) -> Optional[str]:
         """
         Render a DataFrame as an HTML table.
@@ -1780,22 +1798,24 @@
     def to_string(
         self,
-        buf=None,
-        columns=None,
-        col_space=None,
-        header=True,
-        index=True,
-        na_rep="NaN",
-        formatters=None,
-        float_format=None,
-        sparsify=None,
-        index_names=True,
-        justify=None,
-        max_rows=None,
-        max_cols=None,
-        show_dimensions=False,
-        decimal=".",
-        line_width=None,
+        buf: Optional[IO[str]] = None,
+        columns: Optional[Sequence[Union[Any, Tuple]]] = None,
+        col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None,
+        header: bool = True,
+        index: bool = True,
+        na_rep: str = "NaN",
+        formatters: Optional[
+            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+        ] = None,
+        float_format: Optional[Callable[[float], str]] = None,
+        sparsify: Optional[bool] = None,
+        index_names: bool = True,
+        justify: Optional[str] = None,
+        max_rows: Optional[int] = None,
+        max_cols: Optional[int] = None,
+        show_dimensions: bool = False,
+        decimal: str = ".",
+        line_width: Optional[int] = None,
     ) -> Optional[str]:
         """
         Render a DataFrame to a console-friendly tabular output.
@@ -1893,7 +1913,7 @@
             psdf._to_internal_pandas(), self.to_string, pd.DataFrame.to_string, args
         )
 
-    def to_dict(self, orient="dict", into=dict) -> Union[List, Mapping]:
+    def to_dict(self, orient: str = "dict", into: Type = dict) -> Union[List, Mapping]:
         """
         Convert the DataFrame to a dictionary.
 
@@ -2318,7 +2338,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
     T = property(transpose)
 
-    def apply(self, func, axis=0, args=(), **kwds) -> Union["Series", "DataFrame", "Index"]:
+    def apply(
+        self, func: Callable, axis: Union[int, str] = 0, args: Sequence[Any] = (), **kwds: Any
+    ) -> Union["Series", "DataFrame", "Index"]:
         """
         Apply a function along an axis of the DataFrame.
 
@@ -2517,7 +2539,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return_sig = spec.annotations.get("return", None)
         should_infer_schema = return_sig is None
 
-        def apply_func(pdf):
+        def apply_func(pdf: pd.DataFrame) -> pd.DataFrame:
             pdf_or_pser = pdf.apply(func, axis=axis, args=args, **kwds)
             if isinstance(pdf_or_pser, pd.Series):
                 return pdf_or_pser.to_frame()
@@ -2625,7 +2647,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return result
 
-    def transform(self, func, axis=0, *args, **kwargs) -> "DataFrame":
+    def transform(
+        self, func: Callable[..., "Series"], axis: Union[int, str] = 0, *args: Any, **kwargs: Any
+    ) -> "DataFrame":
         """
         Call ``func`` on self producing a Series with transformed values
         and that has the same length as its input.
@@ -2781,7 +2805,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             lambda psser: psser.pandas_on_spark.transform_batch(func, *args, **kwargs)
         )
 
-    def pop(self, item) -> "DataFrame":
+    def pop(self, item: Union[Any, Tuple]) -> "DataFrame":
         """
         Return item and drop from frame. Raise KeyError if not found.
 
@@ -2860,7 +2884,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return result
 
     # TODO: add axis parameter can work when '1' or 'columns'
-    def xs(self, key, axis=0, level=None) -> Union["DataFrame", "Series"]:
+    def xs(
+        self, key: Union[Any, Tuple], axis: Union[int, str] = 0, level: Optional[int] = None
+    ) -> Union["DataFrame", "Series"]:
         """
         Return cross-section from the DataFrame.
@@ -3172,7 +3198,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             )
         )
 
-    def where(self, cond, other=np.nan) -> "DataFrame":
+    def where(
+        self, cond: Union["DataFrame", "Series"], other: Union["DataFrame", "Series", Any] = np.nan
+    ) -> "DataFrame":
         """
         Replace values where the condition is False.
 
@@ -3363,7 +3391,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             )
         )
 
-    def mask(self, cond, other=np.nan) -> "DataFrame":
+    def mask(
+        self, cond: Union["DataFrame", "Series"], other: Union["DataFrame", "Series", Any] = np.nan
+    ) -> "DataFrame":
         """
         Replace values where the condition is True.
 
@@ -3498,7 +3528,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             warnings.warn("'style' property will only use top %s rows." % max_results, UserWarning)
         return pdf.head(max_results).style
 
-    def set_index(self, keys, drop=True, append=False, inplace=False) -> Optional["DataFrame"]:
+    def set_index(
+        self,
+        keys: Union[Any, Tuple, List[Union[Any, Tuple]]],
+        drop: bool = True,
+        append: bool = False,
+        inplace: bool = False,
+    ) -> Optional["DataFrame"]:
         """Set the DataFrame index (row labels) using one or more existing columns.
 
         Set the DataFrame index (row labels) using one or more existing
@@ -3563,32 +3599,34 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         inplace = validate_bool_kwarg(inplace, "inplace")
         if is_name_like_tuple(keys):
-            keys = [keys]
+            key_list = [cast(Tuple, keys)]  # type: List[Tuple]
         elif is_name_like_value(keys):
-            keys = [(keys,)]
+            key_list = [(keys,)]
         else:
-            keys = [key if is_name_like_tuple(key) else (key,) for key in keys]
+            key_list = [key if is_name_like_tuple(key) else (key,) for key in keys]
         columns = set(self._internal.column_labels)
-        for key in keys:
+        for key in key_list:
             if key not in columns:
                 raise KeyError(name_like_string(key))
 
         if drop:
-            column_labels = [label for label in self._internal.column_labels if label not in keys]
+            column_labels = [
+                label for label in self._internal.column_labels if label not in key_list
+            ]
         else:
             column_labels = self._internal.column_labels
         if append:
             index_spark_columns = self._internal.index_spark_columns + [
-                self._internal.spark_column_for(label) for label in keys
+                self._internal.spark_column_for(label) for label in key_list
             ]
-            index_names = self._internal.index_names + keys
+            index_names = self._internal.index_names + key_list
             index_fields = self._internal.index_fields + [
-                self._internal.field_for(label) for label in keys
+                self._internal.field_for(label) for label in key_list
             ]
         else:
-            index_spark_columns = [self._internal.spark_column_for(label) for label in keys]
-            index_names = keys
-            index_fields = [self._internal.field_for(label) for label in keys]
+            index_spark_columns = [self._internal.spark_column_for(label) for label in key_list]
+            index_names = key_list
+            index_fields = [self._internal.field_for(label) for label in key_list]
 
         internal = self._internal.copy(
             index_spark_columns=index_spark_columns,
@@ -3758,7 +3796,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         inplace = validate_bool_kwarg(inplace, "inplace")
         multi_index = self._internal.index_level > 1
 
-        def rename(index):
+        def rename(index: int) -> Tuple:
             if multi_index:
                 return ("level_{}".format(index),)
             else:
@@ -3950,7 +3988,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def insert(
         self,
         loc: int,
-        column,
+        column: Union[Any, Tuple],
         value: Union[Scalar, "Series", Iterable],
         allow_duplicates: bool = False,
     ) -> None:
@@ -4029,7 +4067,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         self._update_internal_frame(psdf._internal)
 
     # TODO: add frep and axis parameter
-    def shift(self, periods=1, fill_value=None) -> "DataFrame":
+    def shift(self, periods: int = 1, fill_value: Optional[Any] = None) -> "DataFrame":
         """
         Shift DataFrame by desired number of periods.
 
@@ -4233,7 +4271,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return first_series(DataFrame(internal).transpose())
 
-    def round(self, decimals=0) -> "DataFrame":
+    def round(
+        self, decimals: Union[int, Dict[Union[Any, Tuple], int], "Series"] = 0
+    ) -> "DataFrame":
         """
         Round a DataFrame to a variable number of decimal places.
 
@@ -4293,42 +4333,44 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
        third 0.9 0.0 0.49
        """
         if isinstance(decimals, ps.Series):
-            decimals = {
+            decimals_dict = {
                 k if isinstance(k, tuple) else (k,): v
                 for k, v in decimals._to_internal_pandas().items()
             }
         elif isinstance(decimals, dict):
-            decimals = {k if is_name_like_tuple(k) else (k,): v for k, v in decimals.items()}
+            decimals_dict = {k if is_name_like_tuple(k) else (k,): v for k, v in decimals.items()}
         elif isinstance(decimals, int):
-            decimals = {k: decimals for k in self._internal.column_labels}
+            decimals_dict = {k: decimals for k in self._internal.column_labels}
         else:
             raise TypeError("decimals must be an integer, a dict-like or a Series")
 
-        def op(psser):
+        def op(psser: ps.Series) -> Union[ps.Series, Column]:
             label = psser._column_label
-            if label in decimals:
-                return F.round(psser.spark.column, decimals[label]).alias(
-                    psser._internal.data_spark_column_names[0]
-                )
+            if label in decimals_dict:
+                return F.round(psser.spark.column, decimals_dict[label])
             else:
                 return psser
 
         return self._apply_series_op(op)
 
-    def _mark_duplicates(self, subset=None, keep="first"):
+    def _mark_duplicates(
+        self,
+        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        keep: str = "first",
+    ) -> Tuple[SparkDataFrame, str]:
         if subset is None:
-            subset = self._internal.column_labels
+            subset_list = self._internal.column_labels
         else:
             if is_name_like_tuple(subset):
-                subset = [subset]
+                subset_list = [cast(Tuple, subset)]
            elif is_name_like_value(subset):
-                subset = [(subset,)]
+                subset_list = [(subset,)]
            else:
-                subset = [sub if is_name_like_tuple(sub) else (sub,) for sub in subset]
+                subset_list = [sub if is_name_like_tuple(sub) else (sub,) for sub in subset]
-            diff = set(subset).difference(set(self._internal.column_labels))
+            diff = set(subset_list).difference(set(self._internal.column_labels))
             if len(diff) > 0:
                 raise KeyError(", ".join([name_like_string(d) for d in diff]))
-        group_cols = [self._internal.spark_column_name_for(label) for label in subset]
+        group_cols = [self._internal.spark_column_name_for(label) for label in subset_list]
 
         sdf = self._internal.resolved_copy.spark_frame
 
@@ -4336,17 +4378,17 @@
 
         if keep == "first" or keep == "last":
             if keep == "first":
-                ord_func = spark.functions.asc
+                ord_func = F.asc
             else:
-                ord_func = spark.functions.desc
+                ord_func = F.desc
             window = (
-                Window.partitionBy(group_cols)
+                Window.partitionBy(*group_cols)
                 .orderBy(ord_func(NATURAL_ORDER_COLUMN_NAME))
                 .rowsBetween(Window.unboundedPreceding, Window.currentRow)
             )
             sdf = sdf.withColumn(column, F.row_number().over(window) > 1)
         elif not keep:
-            window = Window.partitionBy(group_cols).rowsBetween(
+            window = Window.partitionBy(*group_cols).rowsBetween(
                 Window.unboundedPreceding, Window.unboundedFollowing
             )
             sdf = sdf.withColumn(column, F.count("*").over(window) > 1)
@@ -4354,7 +4396,11 @@
             raise ValueError("'keep' only supports 'first', 'last' and False")
         return sdf, column
 
-    def duplicated(self, subset=None, keep="first") -> "Series":
+    def duplicated(
+        self,
+        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        keep: str = "first",
+    ) -> "Series":
         """
         Return boolean Series denoting duplicate rows, optionally only considering certain columns.
 
@@ -4513,7 +4559,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return cast(ps.Series, other.dot(self.transpose())).rename(None)
 
-    def __matmul__(self, other):
+    def __matmul__(self, other: "Series") -> "Series":
         """
         Matrix multiplication using binary `@` operator in Python>=3.5.
         """
@@ -4577,7 +4623,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if isinstance(self, DataFrame):
             return self
         else:
-            assert isinstance(self, spark.DataFrame), type(self)
+            assert isinstance(self, SparkDataFrame), type(self)
             from pyspark.pandas.namespace import _get_index_map
 
             index_spark_columns, index_names = _get_index_map(self, index_col)
@@ -4601,7 +4647,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: Any
     ) -> None:
         return self.spark.to_table(name, format, mode, partition_cols, index_col, **options)
 
@@ -4613,7 +4659,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: "OptionalPrimitiveType"
    ) -> None:
         """
         Write the DataFrame out as a Delta Lake table.
@@ -4692,7 +4738,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         partition_cols: Optional[Union[str, List[str]]] = None,
         compression: Optional[str] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: Any
     ) -> None:
         """
         Write the DataFrame out as a Parquet file or directory.
@@ -4763,7 +4809,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: "OptionalPrimitiveType"
     ) -> None:
         """
         Write the DataFrame out as a ORC file or directory.
@@ -4835,7 +4881,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         mode: str = "overwrite",
         partition_cols: Optional[Union[str, List[str]]] = None,
         index_col: Optional[Union[str, List[str]]] = None,
-        **options
+        **options: "OptionalPrimitiveType"
     ) -> None:
         """An alias for :func:`DataFrame.spark.to_spark_io`.
         See :meth:`pyspark.pandas.spark.accessors.SparkFrameMethods.to_spark_io`.
@@ -4873,7 +4919,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return self._internal.to_pandas_frame.copy()
 
-    def assign(self, **kwargs) -> "DataFrame":
+    def assign(self, **kwargs: Any) -> "DataFrame":
         """
         Assign new columns to a DataFrame.
@@ -4934,14 +4980,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return self._assign(kwargs)
 
-    def _assign(self, kwargs):
+    def _assign(self, kwargs: Any) -> "DataFrame":
         assert isinstance(kwargs, dict)
         from pyspark.pandas.indexes import MultiIndex
         from pyspark.pandas.series import IndexOpsMixin
 
         for k, v in kwargs.items():
             is_invalid_assignee = (
-                not (isinstance(v, (IndexOpsMixin, spark.Column)) or callable(v) or is_scalar(v))
+                not (isinstance(v, (IndexOpsMixin, Column)) or callable(v) or is_scalar(v))
             ) or isinstance(v, MultiIndex)
             if is_invalid_assignee:
                 raise TypeError(
@@ -4955,7 +5001,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             (v.spark.column, v._internal.data_fields[0])
             if isinstance(v, IndexOpsMixin) and not isinstance(v, MultiIndex)
             else (v, None)
-            if isinstance(v, spark.Column)
+            if isinstance(v, Column)
             else (F.lit(v), None)
             for k, v in kwargs.items()
@@ -5062,7 +5108,16 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             pd.DataFrame.from_records(data, index, exclude, columns, coerce_float, nrows)
         )
 
-    def to_records(self, index=True, column_dtypes=None, index_dtypes=None) -> np.recarray:
+    def to_records(
+        self,
+        index: bool = True,
+        column_dtypes: Optional[
+            Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
+        ] = None,
+        index_dtypes: Optional[
+            Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
+        ] = None,
+    ) -> np.recarray:
         """
         Convert DataFrame to a NumPy record array.
 
@@ -5171,7 +5226,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return DataFrame(self._internal)
 
     def dropna(
-        self, axis=0, how="any", thresh=None, subset=None, inplace=False
+        self,
+        axis: Union[int, str] = 0,
+        how: str = "any",
+        thresh: Optional[int] = None,
+        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        inplace: bool = False,
     ) -> Optional["DataFrame"]:
         """
         Remove missing values.
@@ -5377,7 +5437,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
     # TODO: add 'limit' when value parameter exists
     def fillna(
-        self, value=None, method=None, axis=None, inplace=False, limit=None
+        self,
+        value: Optional[Union[Any, Dict[Union[Any, Tuple], Any]]] = None,
+        method: Optional[str] = None,
+        axis: Optional[Union[int, str]] = None,
+        inplace: bool = False,
+        limit: Optional[int] = None,
    ) -> Optional["DataFrame"]:
         """Fill NA/NaN values.
 
@@ -5474,7 +5539,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     raise TypeError("Unsupported type %s" % type(v).__name__)
             value = {k if is_name_like_tuple(k) else (k,): v for k, v in value.items()}
 
-            def op(psser):
+            def op(psser: ps.Series) -> ps.Series:
                 label = psser._column_label
                 for k, v in value.items():
                     if k == label[: len(k)]:
@@ -5502,12 +5567,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
     def replace(
         self,
-        to_replace=None,
-        value=None,
-        inplace=False,
-        limit=None,
-        regex=False,
-        method="pad",
+        to_replace: Optional[Union[Any, List, Tuple, Dict]] = None,
+        value: Optional[Any] = None,
+        inplace: bool = False,
+        limit: Optional[int] = None,
+        regex: bool = False,
+        method: str = "pad",
     ) -> Optional["DataFrame"]:
         """
         Returns a new DataFrame replacing a value with another value.
@@ -5611,11 +5676,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if isinstance(to_replace, dict) and (
             value is not None or all(isinstance(i, dict) for i in to_replace.values())
         ):
+            to_replace_dict = cast(dict, to_replace)
 
-            def op(psser):
-                if psser.name in to_replace:
+            def op(psser: ps.Series) -> ps.Series:
+                if psser.name in to_replace_dict:
                     return psser.replace(
-                        to_replace=to_replace[psser.name], value=value, regex=regex
+                        to_replace=to_replace_dict[psser.name], value=value, regex=regex
                     )
                 else:
                     return psser
@@ -5846,7 +5912,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return cast(DataFrame, self.loc[:to_date])
 
     def pivot_table(
-        self, values=None, index=None, columns=None, aggfunc="mean", fill_value=None
+        self,
+        values: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        index: Optional[List[Union[Any, Tuple]]] = None,
+        columns: Optional[Union[Any, Tuple]] = None,
+        aggfunc: Union[str, Dict[Union[Any, Tuple], str]] = "mean",
+        fill_value: Optional[Any] = None,
     ) -> "DataFrame":
         """
         Create a spreadsheet-style pivot table as a DataFrame. The levels in
@@ -6066,7 +6137,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 tuple(list(column_name_to_index[name.split("_")[1]]) + [name.split("_")[0]])
                 for name in data_columns
             ]
-            column_label_names = ([None] * column_labels_level(values)) + [columns]
+            column_label_names = (
+                [cast(Optional[Union[Any, Tuple]], None)] * column_labels_level(values)
+            ) + [columns]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_spark_columns=[scol_for(sdf, col) for col in index_columns],
@@ -6074,12 +6147,14 @@
                 index_fields=index_fields,
                 column_labels=column_labels,
                 data_spark_columns=[scol_for(sdf, col) for col in data_columns],
-                column_label_names=column_label_names,  # type: ignore
+                column_label_names=column_label_names,
             )
             psdf = DataFrame(internal)  # type: "DataFrame"
         else:
             column_labels = [tuple(list(values[0]) + [column]) for column in data_columns]
-            column_label_names = ([None] * len(values[0])) + [columns]
+            column_label_names = (
+                [cast(Optional[Union[Any, Tuple]], None)] * len(values[0])
+            ) + [columns]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_spark_columns=[scol_for(sdf, col) for col in index_columns],
@@ -6087,7 +6162,7 @@
                 index_fields=index_fields,
                 column_labels=column_labels,
                 data_spark_columns=[scol_for(sdf, col) for col in data_columns],
-                column_label_names=column_label_names,  # type: ignore
+                column_label_names=column_label_names,
             )
             psdf = DataFrame(internal)
         else:
@@ -6132,7 +6207,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         return psdf
 
-    def pivot(self, index=None, columns=None, values=None) -> "DataFrame":
+    def pivot(
+        self,
+        index: Optional[Union[Any, Tuple]] = None,
+        columns: Optional[Union[Any, Tuple]] = None,
+        values: Optional[Union[Any, Tuple]] = None,
+    ) -> "DataFrame":
         """
         Return reshaped DataFrame organized by given index / column values.
 
@@ -6248,15 +6328,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         should_use_existing_index = index is not None
         if should_use_existing_index:
             df = self
-            index = [index]
+            index_labels = [index]
         else:
             # The index after `reset_index()` will never be used, so use "distributed" index
             # as a dummy to avoid overhead.
             with option_context("compute.default_index_type", "distributed"):
                 df = self.reset_index()
-            index = df._internal.column_labels[: self._internal.index_level]
+            index_labels = df._internal.column_labels[: self._internal.index_level]
 
-        df = df.pivot_table(index=index, columns=columns, values=values, aggfunc="first")
+        df = df.pivot_table(index=index_labels, columns=columns, values=values, aggfunc="first")
 
         if should_use_existing_index:
             return df
@@ -6278,7 +6358,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return columns
 
     @columns.setter
-    def columns(self, columns) -> None:
+    def columns(self, columns: Union[pd.Index, List[Union[Any, Tuple]]]) -> None:
         if isinstance(columns, pd.MultiIndex):
             column_labels = columns.tolist()
         else:
@@ -6346,7 +6426,11 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             ),
         )
 
-    def select_dtypes(self, include=None, exclude=None) -> "DataFrame":
+    def select_dtypes(
+        self,
+        include: Optional[Union[str, List[str]]] = None,
+        exclude: Optional[Union[str, List[str]]] = None,
+    ) -> "DataFrame":
         """
         Return a subset of the DataFrame's columns based on the column dtypes.
 
@@ -6447,31 +6531,35 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         from pyspark.sql.types import _parse_datatype_string  # type: ignore
 
         if not is_list_like(include):
-            include = (include,) if include is not None else ()
+            include_list = [include] if include is not None else []
+        else:
+            include_list = list(include)
         if not is_list_like(exclude):
-            exclude = (exclude,) if exclude is not None else ()
+            exclude_list = [exclude] if exclude is not None else []
+        else:
+            exclude_list = list(exclude)
 
-        if not any((include, exclude)):
+        if not any((include_list, exclude_list)):
             raise ValueError("at least one of include or exclude must be " "nonempty")
 
         # can't both include AND exclude!
-        if set(include).intersection(set(exclude)):
+        if set(include_list).intersection(set(exclude_list)):
             raise ValueError(
                 "include and exclude overlap on {inc_ex}".format(
-                    inc_ex=set(include).intersection(set(exclude))
+                    inc_ex=set(include_list).intersection(set(exclude_list))
                 )
             )
 
         # Handle Spark types
         include_spark_type = []
-        for inc in include:
+        for inc in include_list:
             try:
                 include_spark_type.append(_parse_datatype_string(inc))
             except:
                 pass
 
         exclude_spark_type = []
-        for exc in exclude:
+        for exc in exclude_list:
             try:
                 exclude_spark_type.append(_parse_datatype_string(exc))
             except:
@@ -6479,14 +6567,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         # Handle pandas types
         include_numpy_type = []
-        for inc in include:
+        for inc in include_list:
             try:
                 include_numpy_type.append(infer_dtype_from_object(inc))
             except:
                 pass
 
         exclude_numpy_type = []
-        for exc in exclude:
+        for exc in exclude_list:
             try:
                 exclude_numpy_type.append(infer_dtype_from_object(exc))
             except:
@@ -6494,7 +6582,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         column_labels = []
         for label in self._internal.column_labels:
-            if len(include) > 0:
+            if len(include_list) > 0:
                 should_include = (
                     infer_dtype_from_object(self._psser_for(label).dtype.name) in include_numpy_type
                     or self._internal.spark_type_for(label) in include_spark_type
@@ -6626,7 +6714,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return psdf
 
     def drop(
-        self, labels=None, axis=1, columns: Union[Any, Tuple, List[Any], List[Tuple]] = None
+        self,
+        labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        axis: Union[int, str] = 1,
+        columns: Union[Any, Tuple, List[Any], List[Tuple]] = None,
     ) -> "DataFrame":
         """
         Drop specified labels from columns.
@@ -6735,8 +6826,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             raise ValueError("Need to specify at least one of 'labels' or 'columns'")
 
     def _sort(
-        self, by: List[Column], ascending: Union[bool, List[bool]], inplace: bool, na_position: str
-    ):
+        self, by: List[Column], ascending: Union[bool, List[bool]], na_position: str
+    ) -> "DataFrame":
         if isinstance(ascending, bool):
             ascending = [ascending] * len(by)
         if len(ascending) != len(by):
@@ -6756,12 +6847,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         }
         by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]
         sdf = self._internal.resolved_copy.spark_frame.sort(*by, NATURAL_ORDER_COLUMN_NAME)
-        psdf = DataFrame(self._internal.with_new_sdf(sdf))  # type: DataFrame
-        if inplace:
-            self._update_internal_frame(psdf._internal)
-            return None
-        else:
-            return psdf
+        return DataFrame(self._internal.with_new_sdf(sdf))
 
     def sort_values(
         self,
@@ -6858,7 +6944,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             )
             new_by.append(ser.spark.column)
 
-        return self._sort(by=new_by, ascending=ascending, inplace=inplace, na_position=na_position)
+        psdf = self._sort(by=new_by, ascending=ascending, na_position=na_position)
+        if inplace:
+            self._update_internal_frame(psdf._internal)
+            return None
+        else:
+            return psdf
 
     def sort_index(
         self,
@@ -6962,7 +7053,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             by = [self._internal.index_spark_columns[level]]  # type: ignore
 
-        return self._sort(by=by, ascending=ascending, inplace=inplace, na_position=na_position)
+        psdf = self._sort(by=by, ascending=ascending, na_position=na_position)
+        if inplace:
+            self._update_internal_frame(psdf._internal)
+            return None
+        else:
+            return psdf
 
     def swaplevel(
         self,
@@ -7331,7 +7427,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return self.sort_values(by=columns, ascending=True).head(n=n)
 
-    def isin(self, values) -> "DataFrame":
+    def isin(self, values: Union[List, Dict]) -> "DataFrame":
         """
         Whether each element in the DataFrame is contained in values.
 
@@ -7612,17 +7708,17 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         how = validate_how(how)
 
-        def resolve(internal, side):
+        def resolve(internal: InternalFrame, side: str) -> InternalFrame:
             rename = lambda col: "__{}_{}".format(side, col)
             internal = internal.resolved_copy
             sdf = internal.spark_frame
             sdf = sdf.select(
-                [
+                *[
                     scol_for(sdf, col).alias(rename(col))
                     for col in sdf.columns
                     if col not in HIDDEN_COLUMNS
-                ]
-                + list(HIDDEN_COLUMNS)
+                ],
+                *HIDDEN_COLUMNS,
            )
             return internal.copy(
                 spark_frame=sdf,
@@ -8161,7 +8257,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return DataFrame(self._internal.with_new_sdf(sdf))
 
-    def astype(self, dtype) -> "DataFrame":
+    def astype(
+        self, dtype: Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
+    ) -> "DataFrame":
         """
         Cast a pandas-on-Spark object to a specified dtype ``dtype``.
 
@@ -8217,15 +8315,16 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         applied = []
         if is_dict_like(dtype):
-            for col_name in dtype.keys():
+            dtype_dict = cast(Dict[Union[Any, Tuple], Union[str, Dtype]], dtype)
+            for col_name in dtype_dict.keys():
                 if col_name not in self.columns:
                     raise KeyError(
                         "Only a column name can be used for the "
                         "key in a dtype mappings argument."
                     )
             for col_name, col in self.items():
-                if col_name in dtype:
-                    applied.append(col.astype(dtype=dtype[col_name]))
+                if col_name in dtype_dict:
+                    applied.append(col.astype(dtype=dtype_dict[col_name]))
                 else:
                     applied.append(col)
         else:
@@ -8233,7 +8332,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 applied.append(col.astype(dtype=dtype))
         return DataFrame(self._internal.with_new_columns(applied))
 
-    def add_prefix(self, prefix) -> "DataFrame":
+    def add_prefix(self, prefix: str) -> "DataFrame":
         """
         Prefix labels with string `prefix`.
 
@@ -8278,7 +8377,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             lambda psser: psser.rename(tuple([prefix + i for i in psser._column_label]))
         )
 
-    def add_suffix(self, suffix) -> "DataFrame":
+    def add_suffix(self, suffix: str) -> "DataFrame":
         """
         Suffix labels with string `suffix`.
 
@@ -8507,7 +8606,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return DataFrame(internal).astype("float64")
 
-    def drop_duplicates(self, subset=None, keep="first", inplace=False) -> Optional["DataFrame"]:
+    def drop_duplicates(
+        self,
+        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        keep: str = "first",
+        inplace: bool = False,
+    ) -> Optional["DataFrame"]:
         """
         Return DataFrame with duplicate rows removed, optionally only considering certain columns.
 
@@ -8587,9 +8691,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
     def reindex(
         self,
-        labels: Optional[Any] = None,
-        index: Optional[Any] = None,
-        columns: Optional[Any] = None,
+        labels: Optional[Sequence[Any]] = None,
+        index: Optional[Union["Index", Sequence[Any]]] = None,
+        columns: Optional[Union[pd.Index, Sequence[Any]]] = None,
         axis: Optional[Union[int, str]] = None,
         copy: Optional[bool] = True,
         fill_value: Optional[Any] = None,
@@ -8771,7 +8875,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return df
 
-    def _reindex_index(self, index, fill_value):
+    def _reindex_index(
+        self, index: Optional[Union["Index", Sequence[Any]]], fill_value: Optional[Any]
+    ) -> "DataFrame":
         # When axis is index, we can mimic pandas' by a right outer join.
         nlevels = self._internal.index_level
         assert nlevels <= 1 or (
@@ -8857,7 +8963,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return DataFrame(internal)
 
-    def _reindex_columns(self, columns, fill_value):
+    def _reindex_columns(
+        self, columns: Optional[Union[pd.Index, Sequence[Any]]], fill_value: Optional[Any]
+    ) -> "DataFrame":
         level = self._internal.column_labels_level
         if level > 1:
             label_columns = list(columns)
@@ -8872,7 +8980,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     "shape (1,{}) doesn't match the shape (1,{})".format(len(col), level)
                 )
         fill_value = np.nan if fill_value is None else fill_value
-        scols_or_pssers, labels = [], []
+        scols_or_pssers = []  # type: List[Union[Series, Column]]
+        labels = []
         for label in label_columns:
             if label in self._internal.column_labels:
                 scols_or_pssers.append(self._psser_for(label))
@@ -9700,7 +9809,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return first_series(DataFrame(internal))
 
     # TODO: add axis, numeric_only, pct, na_option parameter
-    def rank(self, method="average", ascending=True) -> "DataFrame":
+    def rank(self, method: str = "average", ascending: bool = True) -> "DataFrame":
         """
         Compute numerical data ranks (1 through n) along axis. Equal values are
         assigned a rank that is the average of the ranks of those values.
@@ -9936,13 +10045,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def rename(
         self,
-        mapper=None,
-        index=None,
-        columns=None,
-        axis="index",
-        inplace=False,
-        level=None,
-        errors="ignore",
+        mapper: Optional[Union[Dict, Callable[[Any], Any]]] = None,
+        index: Optional[Union[Dict, Callable[[Any], Any]]] = None,
+        columns: Optional[Union[Dict, Callable[[Any], Any]]] = None,
+        axis: Union[int, str] = "index",
+        inplace: bool = False,
+        level: Optional[int] = None,
+        errors: str = "ignore",
     ) -> Optional["DataFrame"]:
         """
@@ -10029,32 +10138,36 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         d  7  8
         """
 
-        def gen_mapper_fn(mapper):
+        def gen_mapper_fn(
+            mapper: Union[Dict, Callable[[Any], Any]]
+        ) -> Tuple[Callable[[Any], Any], DataType]:
             if isinstance(mapper, dict):
-                if len(mapper) == 0:
+                mapper_dict = cast(dict, mapper)
+                if len(mapper_dict) == 0:
                     if errors == "raise":
                         raise KeyError("Index include label which is not in the `mapper`.")
                     else:
                         return DataFrame(self._internal)
 
-                type_set = set(map(lambda x: type(x), mapper.values()))
+                type_set = set(map(lambda x: type(x), mapper_dict.values()))
                 if len(type_set) > 1:
                     raise ValueError("Mapper dict should have the same value type.")
                 spark_return_type = as_spark_type(list(type_set)[0])
 
-                def mapper_fn(x):
-                    if x in mapper:
-                        return mapper[x]
+                def mapper_fn(x: Any) -> Any:
+                    if x in mapper_dict:
+                        return mapper_dict[x]
                     else:
                         if errors == "raise":
                             raise KeyError("Index include value which is not in the `mapper`")
                         return x
 
             elif callable(mapper):
+                mapper_callable = cast(Callable, mapper)
                 spark_return_type = cast(ScalarType, infer_return_type(mapper)).spark_type
 
-                def mapper_fn(x):
-                    return mapper(x)
+                def mapper_fn(x: Any) -> Any:
+                    return mapper_callable(x)
 
             else:
                 raise ValueError(
@@ -10109,10 +10222,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 if level < 0 or level >= num_indices:
                     raise ValueError("level should be an integer between [0, num_indices)")
 
-            def gen_new_index_column(level):
+            def gen_new_index_column(level: int) -> Column:
                 index_col_name = index_columns[level]
 
-                @pandas_udf(returnType=index_mapper_ret_stype)
+                @pandas_udf(returnType=index_mapper_ret_stype)  # type: ignore
                 def index_mapper_udf(s: pd.Series) -> pd.Series:
                     return s.map(index_mapper_fn)
 
@@ -10136,18 +10249,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             if level < 0 or level >= psdf._internal.column_labels_level:
                 raise ValueError("level should be an integer between [0, column_labels_level)")
 
-            def gen_new_column_labels_entry(column_labels_entry):
-                if isinstance(column_labels_entry, tuple):
-                    if level is None:
-                        # rename all level columns
-                        return tuple(map(columns_mapper_fn, column_labels_entry))
-                    else:
-                        # only rename specified level column
-                        entry_list = list(column_labels_entry)
-                        entry_list[level] = columns_mapper_fn(entry_list[level])
-                        return tuple(entry_list)
+            def gen_new_column_labels_entry(column_labels_entry: Tuple) -> Tuple:
+                if level is None:
+                    # rename all level columns
+                    return tuple(map(columns_mapper_fn, column_labels_entry))
                 else:
-                    return columns_mapper_fn(column_labels_entry)
+                    # only rename specified level column
+                    entry_list = list(column_labels_entry)
+                    entry_list[level] = columns_mapper_fn(entry_list[level])
+                    return tuple(entry_list)
 
             new_column_labels = list(map(gen_new_column_labels_entry, psdf._internal.column_labels))
 
@@ -10164,9 +10274,15 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def rename_axis(
         self,
-        mapper: Optional[Any] = None,
-        index: Optional[Any] = None,
-        columns: Optional[Any] = None,
+        mapper: Union[
+            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
+        ] = None,
+        index: Union[
+            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
+        ] = None,
+        columns: Union[
+            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
+        ] = None,
         axis: Optional[Union[int, str]] = 0,
         inplace: Optional[bool] = False,
     ) -> Optional["DataFrame"]:
@@ -10275,15 +10391,22 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         monkey          2        2
         """
 
-        def gen_names(v, curnames):
+        def gen_names(
+            v: Union[
+                Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
+            ],
+            curnames: List[Union[Any, Tuple]],
+        ) -> List[Tuple]:
             if is_scalar(v):
-                newnames = [v]
+                newnames = [cast(Any, v)]  # type: List[Union[Any, Tuple]]
             elif is_list_like(v) and not is_dict_like(v):
-                newnames = list(v)
+                newnames = list(cast(Sequence[Any], v))
             elif is_dict_like(v):
-                newnames = [v[name] if name in v else name for name in curnames]
+                v_dict = cast(Dict[Union[Any, Tuple], Any], v)
+                newnames = [v_dict[name] if name in v_dict else name for name in curnames]
             elif callable(v):
-                newnames = [v(name) for name in curnames]
+                v_callable = cast(Callable[[Union[Any, Tuple]], Any], v)
+                newnames = [v_callable(name) for name in curnames]
             else:
                 raise ValueError(
                     "`mapper` or `index` or `columns` should be "
@@ -10350,7 +10473,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return self.columns
 
-    def pct_change(self, periods=1) -> "DataFrame":
+    def pct_change(self, periods: int = 1) -> "DataFrame":
         """
         Percentage change between the current and a prior element.
 
@@ -10400,7 +10523,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         window = Window.orderBy(NATURAL_ORDER_COLUMN_NAME).rowsBetween(-periods, -periods)
 
-        def op(psser):
+        def op(psser: ps.Series) -> Column:
             prev_row = F.lag(psser.spark.column, periods).over(window)
             return ((psser.spark.column - prev_row) / prev_row).alias(
                 psser._internal.data_spark_column_names[0]
@@ -10409,7 +10532,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return self._apply_series_op(op, should_resolve=True)
 
     # TODO: axis = 1
-    def idxmax(self, axis=0) -> "Series":
+    def idxmax(self, axis: Union[int, str] = 0) -> "Series":
         """
         Return index of first occurrence of maximum over requested axis.
         NA/null values are excluded.
 
@@ -10487,7 +10610,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmax()))
 
     # TODO: axis = 1
-    def idxmin(self, axis=0) -> "Series":
+    def idxmin(self, axis: Union[int, str] = 0) -> "Series":
         """
         Return index of first occurrence of minimum over requested axis.
         NA/null values are excluded.
 
@@ -10558,7 +10681,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         return cast(ps.Series, ps.from_pandas(psdf._to_internal_pandas().idxmin()))
 
-    def info(self, verbose=None, buf=None, max_cols=None, null_counts=None) -> None:
+    def info(
+        self,
+        verbose: Optional[bool] = None,
+        buf: Optional[IO[str]] = None,
+        max_cols: Optional[int] = None,
+        null_counts: Optional[bool] = None,
+    ) -> None:
         """
         Print a concise summary of a DataFrame.
 
@@ -10831,7 +10960,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             quantile, name="quantile", numeric_only=numeric_only
         ).rename(qq)
 
-    def query(self, expr, inplace=False) -> Optional["DataFrame"]:
+    def query(self, expr: str, inplace: bool = False) -> Optional["DataFrame"]:
         """
         Query the columns of a DataFrame with a boolean expression.
@@ -10935,7 +11064,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return DataFrame(internal)
 
-    def take(self, indices, axis=0, **kwargs) -> "DataFrame":
+    def take(self, indices: List[int], axis: Union[int, str] = 0, **kwargs: Any) -> "DataFrame":
         """
         Return the elements in the given *positional* indices along an axis.
 
@@ -11016,7 +11145,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             return cast(DataFrame, self.iloc[:, indices])
 
-    def eval(self, expr, inplace=False) -> Optional[Union["DataFrame", "Series"]]:
+    def eval(self, expr: str, inplace: bool = False) -> Optional[Union["DataFrame", "Series"]]:
         """
         Evaluate a string describing operations on DataFrame columns.
 
@@ -11105,6 +11234,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         # Since `eval_func` doesn't have a type hint, inferring the schema is always preformed
         # in the `apply_batch`. Hence, the variables `should_return_series`, `series_name`,
         # and `should_return_scalar` can be updated.
+        @no_type_check
         def eval_func(pdf):
             nonlocal should_return_series
             nonlocal series_name
@@ -11135,7 +11265,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             # Returns a frame
             return result
 
-    def explode(self, column) -> "DataFrame":
+    def explode(self, column: Union[Any, Tuple]) -> "DataFrame":
         """
         Transform each element of a list-like to a row, replicating index values.
 
@@ -11199,7 +11329,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         internal = psdf._internal.with_new_sdf(sdf, data_fields=data_fields)
         return DataFrame(internal)
 
-    def mad(self, axis=0) -> "Series":
+    def mad(self, axis: int = 0) -> "Series":
         """
         Return the mean absolute deviation of values.
 
@@ -11231,7 +11361,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
         if axis == 0:
 
-            def get_spark_column(psdf, label):
+            def get_spark_column(psdf: DataFrame, label: Tuple) -> Column:
                 scol = psdf._internal.spark_column_for(label)
                 col_type = psdf._internal.spark_type_for(label)
 
@@ -11240,7 +11370,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
 
                 return scol
 
-            new_column_labels = []
+            new_column_labels = []  # type: List[Tuple]
             for label in self._internal.column_labels:
                 # Filtering out only columns of numeric and boolean type column.
                 dtype = self._psser_for(label).spark.data_type
@@ -11252,7 +11382,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 for label in new_column_labels
             ]
 
-            mean_data = self._internal.spark_frame.select(new_columns).first()
+            mean_data = self._internal.spark_frame.select(*new_columns).first()
 
             new_columns = [
                 F.avg(
@@ -11262,7 +11392,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             ]
 
             sdf = self._internal.spark_frame.select(
-                [F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)] + new_columns
+                *[F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)], *new_columns
             )
 
             # The data is expected to be small so it's fine to transpose/use default index.
@@ -11293,7 +11423,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return first_series(DataFrame(internal))
 
-    def tail(self, n=5) -> "DataFrame":
+    def tail(self, n: int = 5) -> "DataFrame":
         """
         Return the last `n` rows.
 
@@ -11553,7 +11683,12 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return (left.copy(), right.copy()) if copy else (left, right)
 
     @staticmethod
-    def from_dict(data, orient="columns", dtype=None, columns=None) -> "DataFrame":
+    def from_dict(
+        data: Dict[Union[Any, Tuple], Sequence[Any]],
+        orient: str = "columns",
+        dtype: Union[str, Dtype] = None,
+        columns: Optional[List[Union[Any, Tuple]]] = None,
+    ) -> "DataFrame":
         """
         Construct DataFrame from dict of array-like or dicts.
@@ -11623,7 +11758,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return DataFrameGroupBy._build(self, by, as_index=as_index, dropna=dropna)

-    def _to_internal_pandas(self):
+    def _to_internal_pandas(self) -> pd.DataFrame:
         """
         Return a pandas DataFrame directly from _internal to avoid overhead of copy.
@@ -11631,14 +11766,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         return self._internal.to_pandas_frame

-    def _get_or_create_repr_pandas_cache(self, n):
+    def _get_or_create_repr_pandas_cache(self, n: int) -> pd.DataFrame:
         if not hasattr(self, "_repr_pandas_cache") or n not in self._repr_pandas_cache:
             object.__setattr__(
                 self, "_repr_pandas_cache", {n: self.head(n + 1)._to_internal_pandas()}
             )
         return self._repr_pandas_cache[n]

-    def __repr__(self):
+    def __repr__(self) -> str:
         max_display_count = get_option("display.max_rows")
         if max_display_count is None:
             return self._to_internal_pandas().to_string()
@@ -11658,7 +11793,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             return REPR_PATTERN.sub(footer, repr_string)
         return pdf.to_string()

-    def _repr_html_(self):
+    def _repr_html_(self) -> str:
         max_display_count = get_option("display.max_rows")
         # pandas 0.25.1 has a regression about HTML representation so 'bold_rows'
         # has to be set as False explicitly. See https://github.com/pandas-dev/pandas/issues/28204
@@ -11683,7 +11818,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             return REPR_HTML_PATTERN.sub(footer, repr_html)
         return pdf.to_html(notebook=True, bold_rows=bold_rows)

-    def __getitem__(self, key):
+    def __getitem__(self, key: Any) -> Any:
         from pyspark.pandas.series import Series

         if key is None:
@@ -11702,7 +11837,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             return self.loc[:, list(key)]
         raise NotImplementedError(key)

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: Any, value: Any) -> None:
         from pyspark.pandas.series import Series

         if isinstance(value, (DataFrame, Series)) and not same_anchor(value, self):
@@ -11711,7 +11846,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             key = DataFrame._index_normalized_label(level, key)
             value = DataFrame._index_normalized_frame(level, value)

-            def assign_columns(psdf, this_column_labels, that_column_labels):
+            def assign_columns(
+                psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
+            ) -> Iterator[Tuple["Series", Tuple]]:
                 assert len(key) == len(that_column_labels)
                 # Note that here intentionally uses `zip_longest` that combine
                 # that_columns.
@@ -11751,7 +11888,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         self._update_internal_frame(psdf._internal)

     @staticmethod
-    def _index_normalized_label(level, labels):
+    def _index_normalized_label(
+        level: int, labels: Union[Any, Tuple, Sequence[Union[Any, Tuple]]]
+    ) -> List[Tuple]:
         """
         Returns a label that is normalized against the current column index level.
         For example, the key "abc" can be ("abc", "", "") if the current Frame has
@@ -11773,7 +11912,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return [tuple(list(label) + ([""] * (level - len(label)))) for label in labels]

     @staticmethod
-    def _index_normalized_frame(level, psser_or_psdf):
+    def _index_normalized_frame(
+        level: int, psser_or_psdf: Union["DataFrame", "Series"]
+    ) -> "DataFrame":
         """
         Returns a frame that is normalized against the current column index level.
         For example, the name in `pd.Series([...], name="abc")` can be
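`_index_normalized_label` makes a plain key addressable under a MultiIndex column by padding it with empty strings up to the column-index level, as the docstring's ("abc", "", "") example describes. A pure-Python restatement of that rule, with hypothetical inputs:

def normalize(level, label):
    # Wrap a scalar key into a 1-tuple, then pad to the target level.
    if not isinstance(label, tuple):
        label = (label,)
    return tuple(list(label) + [""] * (level - len(label)))

assert normalize(3, "abc") == ("abc", "", "")
assert normalize(3, ("a", "b")) == ("a", "b", "")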
@@ -11813,7 +11954,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 "'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
             )

-    def __setattr__(self, key: str, value) -> None:
+    def __setattr__(self, key: str, value: Any) -> None:
         try:
             object.__getattribute__(self, key)
             return object.__setattr__(self, key, value)
@@ -11829,20 +11970,22 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         else:
             warnings.warn(msg, UserWarning)

-    def __len__(self):
+    def __len__(self) -> int:
         return self._internal.resolved_copy.spark_frame.count()

-    def __dir__(self):
+    def __dir__(self) -> Iterable[str]:
         fields = [
             f for f in self._internal.resolved_copy.spark_frame.schema.fieldNames() if " " not in f
         ]
-        return super().__dir__() + fields
+        return list(super().__dir__()) + fields

-    def __iter__(self):
+    def __iter__(self) -> Iterator[Union[Any, Tuple]]:
         return iter(self.columns)

     # NDArray Compat
-    def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
+    def __array_ufunc__(
+        self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any
+    ) -> "DataFrame":
         # TODO: is it possible to deduplicate it with '_map_series_op'?
         if all(isinstance(inp, DataFrame) for inp in inputs) and any(
             not same_anchor(inp, inputs[0]) for inp in inputs
@@ -11855,7 +11998,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 raise ValueError("cannot join with no overlapping index names")

             # Different DataFrames
-            def apply_op(psdf, this_column_labels, that_column_labels):
+            def apply_op(
+                psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
+            ) -> Iterator[Tuple["Series", Tuple]]:
                 for this_label, that_label in zip(this_column_labels, that_column_labels):
                     yield (
                         ufunc(
@@ -11883,7 +12028,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]

     if sys.version_info >= (3, 7):

-        def __class_getitem__(cls, params):
+        def __class_getitem__(cls, params: Any) -> object:
             # This is a workaround to support variadic generic in DataFrame in Python 3.7.
             # See https://github.com/python/typing/issues/193
             # we always wrap the given type hints in a tuple to mimic the variadic generic.
@@ -11896,13 +12041,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         is_dataframe = None


-def _reduce_spark_multi(sdf, aggs):
+def _reduce_spark_multi(sdf: SparkDataFrame, aggs: List[Column]) -> Any:
     """
     Performs a reduction on a spark DataFrame, the functions being known sql aggregate functions.
     """
-    assert isinstance(sdf, spark.DataFrame)
+    assert isinstance(sdf, SparkDataFrame)
     sdf0 = sdf.agg(*aggs)
-    l = sdf0.limit(2).toPandas()
+    l = cast(pd.DataFrame, sdf0.limit(2).toPandas())
     assert len(l) == 1, (sdf, l)
     row = l.iloc[0]
     l2 = list(row)
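`_reduce_spark_multi` relies on `agg(*exprs)` computing every aggregate in a single pass and returning a one-row DataFrame, which is what the `assert len(l) == 1` above checks. A small sketch of that behavior, assuming a local SparkSession:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, 10.0), (2, 20.0)], ["a", "b"])

# One job, one output row holding all the aggregates.
pdf = sdf.agg(F.max("a"), F.avg("b")).limit(2).toPandas()
assert len(pdf) == 1
print(list(pdf.iloc[0]))  # [2, 15.0]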
""" - def __init__(self, internal, storage_level=None): + def __init__(self, internal: InternalFrame, storage_level: Optional[StorageLevel] = None): if storage_level is None: object.__setattr__(self, "_cached", internal.spark_frame.cache()) elif isinstance(storage_level, StorageLevel): @@ -11927,17 +12072,23 @@ class CachedDataFrame(DataFrame): ) super().__init__(internal) - def __enter__(self): + def __enter__(self) -> "CachedDataFrame": return self - def __exit__(self, exception_type, exception_value, traceback): + def __exit__( + self, + exception_type: Optional[Type[BaseException]], + exception_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: self.spark.unpersist() + return None # create accessor for Spark related methods. spark = CachedAccessor("spark", CachedSparkFrameMethods) -def _test(): +def _test() -> None: import os import doctest import shutil diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 03864ef182..3a33295eb8 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -96,7 +96,9 @@ class Frame(object, metaclass=ABCMeta): @abstractmethod def _apply_series_op( - self: T_Frame, op: Callable[["Series"], "Series"], should_resolve: bool = False + self: T_Frame, + op: Callable[["Series"], Union["Series", Column]], + should_resolve: bool = False, ) -> T_Frame: pass @@ -2096,11 +2098,13 @@ class Frame(object, metaclass=ABCMeta): 3 7 40 50 """ - def abs(psser: "Series") -> "Series": + def abs(psser: "Series") -> Union["Series", Column]: if isinstance(psser.spark.data_type, BooleanType): return psser elif isinstance(psser.spark.data_type, NumericType): - return psser.spark.transform(F.abs) + return psser._with_new_scol( + F.abs(psser.spark.column), field=psser._internal.data_fields[0] + ) else: raise TypeError( "bad operand type for abs(): {} ({})".format( diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index bc8bac2ae8..860540e8ed 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -31,6 +31,7 @@ from typing import ( Callable, Dict, Generic, + Iterator, Mapping, List, Optional, @@ -2632,7 +2633,7 @@ class GroupBy(Generic[T_Frame], metaclass=ABCMeta): def assign_columns( psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple] - ) -> Tuple[Series, Tuple]: + ) -> Iterator[Tuple[Series, Tuple]]: raise NotImplementedError( "Duplicated labels with groupby() and " "'compute.ops_on_diff_frames' option are not supported currently " diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index f602d24790..d1ecdb0a39 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -30,6 +30,7 @@ from typing import ( # noqa: F401 (SPARK-34943) Union, cast, no_type_check, + overload, ) from collections import OrderedDict from collections.abc import Iterable @@ -1130,7 +1131,7 @@ def read_excel( else: pdf = pdf_or_pser - psdf = from_pandas(pdf) + psdf = cast(DataFrame, from_pandas(pdf)) return_schema = force_decimal_precision_scale( as_nullable_spark_type(psdf._internal.spark_frame.drop(*HIDDEN_COLUMNS).schema) ) diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 01296d9c90..3de72436e6 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -2500,7 +2500,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): """ inplace = validate_bool_kwarg(inplace, "inplace") psdf = 
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 01296d9c90..3de72436e6 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -2500,7 +2500,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         """
         inplace = validate_bool_kwarg(inplace, "inplace")
         psdf = self._psdf[[self.name]]._sort(
-            by=[self.spark.column], ascending=ascending, inplace=False, na_position=na_position
+            by=[self.spark.column], ascending=ascending, na_position=na_position
         )

         if inplace:
@@ -3041,7 +3041,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):

     sample.__doc__ = DataFrame.sample.__doc__

-    def hist(self, bins: int = 10, **kwds: Any) -> Any:
+    @no_type_check
+    def hist(self, bins=10, **kwds):
         return self.plot.hist(bins, **kwds)

     hist.__doc__ = PandasOnSparkPlotAccessor.hist.__doc__
@@ -4963,17 +4964,19 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if not self.index.sort_values().equals(other.index.sort_values()):
             raise ValueError("matrices are not aligned")

-        other = other.copy()
-        column_labels = other._internal.column_labels
+        other_copy = other.copy()  # type: DataFrame
+        column_labels = other_copy._internal.column_labels

-        self_column_label = verify_temp_column_name(other, "__self_column__")
-        other[self_column_label] = self
-        self_psser = other._psser_for(self_column_label)
+        self_column_label = verify_temp_column_name(other_copy, "__self_column__")
+        other_copy[self_column_label] = self
+        self_psser = other_copy._psser_for(self_column_label)

-        product_pssers = [other._psser_for(label) * self_psser for label in column_labels]
+        product_pssers = [
+            cast(Series, other_copy._psser_for(label) * self_psser) for label in column_labels
+        ]

         dot_product_psser = DataFrame(
-            other._internal.with_new_columns(product_pssers, column_labels=column_labels)
+            other_copy._internal.with_new_columns(product_pssers, column_labels=column_labels)
         ).sum()
         return cast(Series, dot_product_psser).rename(self.name)
@@ -6158,9 +6161,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
     # ----------------------------------------------------------------------

     def _apply_series_op(
-        self, op: Callable[["Series"], "Series"], should_resolve: bool = False
+        self, op: Callable[["Series"], Union["Series", Column]], should_resolve: bool = False
     ) -> "Series":
-        psser = op(self)
+        psser_or_scol = op(self)
+        if isinstance(psser_or_scol, Series):
+            psser = psser_or_scol
+        else:
+            psser = self._with_new_scol(cast(Column, psser_or_scol))
         if should_resolve:
             internal = psser._internal.resolved_copy
             return first_series(DataFrame(internal))
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index de0e0a6e8d..6f2ed7bf37 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -307,7 +307,9 @@ def combine_frames(


 def align_diff_frames(
-    resolve_func: Callable[["DataFrame", List[Tuple], List[Tuple]], Tuple["Series", Tuple]],
+    resolve_func: Callable[
+        ["DataFrame", List[Tuple], List[Tuple]], Iterator[Tuple["Series", Tuple]]
+    ],
     this: "DataFrame",
     that: "DataFrame",
     fillna: bool = True,
@@ -419,7 +421,7 @@ def align_diff_frames(
     if len(this_columns_to_apply) > 0 or len(that_columns_to_apply) > 0:
         psser_set, column_labels_set = zip(
             *resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
-        )  # type: Tuple[Iterable[Series], Iterable[Tuple]]
+        )
         columns_applied = list(psser_set)  # type: List[Union[Series, spark.Column]]
         column_labels_applied = list(column_labels_set)  # type: List[Tuple]
     else:
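The `assign_columns`, `apply_op`, and `align_diff_frames` hunks all make the same correction: a function body containing `yield` is a generator, so its declared return type must be `Iterator[...]` (or `Generator[...]`), not the type of a single yielded item. A toy illustration of what mypy expects, with hypothetical names:

from typing import Iterator, List, Tuple

def labeled(items: List[str]) -> Iterator[Tuple[str, int]]:
    # Declaring `-> Tuple[str, int]` here would be rejected by mypy,
    # since the yield makes this function return a generator.
    for i, item in enumerate(items):
        yield (item, i)

print(list(labeled(["a", "b"])))  # [('a', 0), ('b', 1)]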