diff --git a/python/pyspark/pandas/_typing.py b/python/pyspark/pandas/_typing.py index 70ce215262..cdc622d482 100644 --- a/python/pyspark/pandas/_typing.py +++ b/python/pyspark/pandas/_typing.py @@ -16,7 +16,7 @@ # import datetime import decimal -from typing import TypeVar, Union, TYPE_CHECKING +from typing import Any, Tuple, TypeVar, Union, TYPE_CHECKING import numpy as np from pandas.api.extensions import ExtensionDtype @@ -40,6 +40,10 @@ Scalar = Union[ int, float, bool, str, bytes, decimal.Decimal, datetime.date, datetime.datetime, None ] +# TODO: use the actual type parameters. +Label = Tuple[Any, ...] +Name = Union[Any, Label] + Axis = Union[int, str] Dtype = Union[np.dtype, ExtensionDtype] diff --git a/python/pyspark/pandas/accessors.py b/python/pyspark/pandas/accessors.py index 0fafeab804..6454938944 100644 --- a/python/pyspark/pandas/accessors.py +++ b/python/pyspark/pandas/accessors.py @@ -28,7 +28,7 @@ from pyspark.sql import functions as F from pyspark.sql.functions import pandas_udf from pyspark.sql.types import DataType, LongType, StructField, StructType -from pyspark.pandas._typing import DataFrameOrSeries +from pyspark.pandas._typing import DataFrameOrSeries, Name from pyspark.pandas.internal import ( InternalField, InternalFrame, @@ -56,7 +56,7 @@ class PandasOnSparkFrameMethods(object): def __init__(self, frame: "DataFrame"): self._psdf = frame - def attach_id_column(self, id_type: str, column: Union[Any, Tuple]) -> "DataFrame": + def attach_id_column(self, id_type: str, column: Name) -> "DataFrame": """ Attach a column to be used as identifier of rows similar to the default index. diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 2744ef91d4..6e76311d25 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -34,7 +34,7 @@ from pyspark.sql.types import ( ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. -from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, SeriesOrIndex +from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex from pyspark.pandas.config import get_option, option_context from pyspark.pandas.internal import ( InternalField, @@ -297,7 +297,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta): @property @abstractmethod - def _column_label(self) -> Optional[Tuple]: + def _column_label(self) -> Optional[Label]: pass @property diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 34c459c9e9..c500ee7376 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -83,7 +83,7 @@ from pyspark.sql.types import ( # noqa: F401 (SPARK-34943) from pyspark.sql.window import Window from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. 
-from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, Scalar, T +from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, Label, Name, Scalar, T from pyspark.pandas.accessors import PandasOnSparkFrameMethods from pyspark.pandas.config import option_context, get_option from pyspark.pandas.spark import functions as SF @@ -520,7 +520,7 @@ class DataFrame(Frame, Generic[T]): object.__setattr__(self, "_internal_frame", internal) @property - def _pssers(self) -> Dict[Tuple, "Series"]: + def _pssers(self) -> Dict[Label, "Series"]: """Return a dict of column label -> Series which anchors `self`.""" from pyspark.pandas.series import Series @@ -746,7 +746,7 @@ class DataFrame(Frame, Generic[T]): ) return first_series(DataFrame(internal)).rename(pser.name) - def _psser_for(self, label: Tuple) -> "Series": + def _psser_for(self, label: Label) -> "Series": """ Create Series with a proper column label. @@ -806,9 +806,9 @@ class DataFrame(Frame, Generic[T]): # Different DataFrames def apply_op( psdf: DataFrame, - this_column_labels: List[Tuple], - that_column_labels: List[Tuple], - ) -> Iterator[Tuple["Series", Tuple]]: + this_column_labels: List[Label], + that_column_labels: List[Label], + ) -> Iterator[Tuple["Series", Label]]: for this_label, that_label in zip(this_column_labels, that_column_labels): yield ( getattr(psdf._psser_for(this_label), op)( @@ -1226,7 +1226,7 @@ class DataFrame(Frame, Generic[T]): return self._apply_series_op(lambda psser: psser.apply(func)) # TODO: not all arguments are implemented comparing to pandas' for now. - def aggregate(self, func: Union[List[str], Dict[Any, List[str]]]) -> "DataFrame": + def aggregate(self, func: Union[List[str], Dict[Name, List[str]]]) -> "DataFrame": """Aggregate using one or more operations over the specified axis. Parameters @@ -1388,7 +1388,7 @@ class DataFrame(Frame, Generic[T]): """ return cast(DataFrame, ps.from_pandas(corr(self, method))) - def iteritems(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]: + def iteritems(self) -> Iterator[Tuple[Name, "Series"]]: """ Iterator over (column name, Series) pairs. @@ -1432,7 +1432,7 @@ class DataFrame(Frame, Generic[T]): for label in self._internal.column_labels ) - def iterrows(self) -> Iterator[Tuple[Union[Any, Tuple], pd.Series]]: + def iterrows(self) -> Iterator[Tuple[Name, pd.Series]]: """ Iterate over DataFrame rows as (index, Series) pairs. 
@@ -1478,7 +1478,7 @@ class DataFrame(Frame, Generic[T]): internal_index_columns = self._internal.index_spark_column_names internal_data_columns = self._internal.data_spark_column_names - def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]: + def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: k = ( row[internal_index_columns[0]] if len(internal_index_columns) == 1 @@ -1567,7 +1567,7 @@ class DataFrame(Frame, Generic[T]): index_spark_column_names = self._internal.index_spark_column_names data_spark_column_names = self._internal.data_spark_column_names - def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]: + def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: k = ( row[index_spark_column_names[0]] if len(index_spark_column_names) == 1 @@ -1592,7 +1592,7 @@ class DataFrame(Frame, Generic[T]): ): yield tuple(([k] if index else []) + list(v)) - def items(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]: + def items(self) -> Iterator[Tuple[Name, "Series"]]: """This is an alias of ``iteritems``.""" return self.iteritems() @@ -1674,13 +1674,13 @@ class DataFrame(Frame, Generic[T]): def to_html( self, buf: Optional[IO[str]] = None, - columns: Optional[Sequence[Union[Any, Tuple]]] = None, - col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None, + columns: Optional[Sequence[Name]] = None, + col_space: Optional[Union[str, int, Dict[Name, Union[str, int]]]] = None, header: bool = True, index: bool = True, na_rep: str = "NaN", formatters: Optional[ - Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]] + Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]] ] = None, float_format: Optional[Callable[[float], str]] = None, sparsify: Optional[bool] = None, @@ -1796,13 +1796,13 @@ class DataFrame(Frame, Generic[T]): def to_string( self, buf: Optional[IO[str]] = None, - columns: Optional[Sequence[Union[Any, Tuple]]] = None, - col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None, + columns: Optional[Sequence[Name]] = None, + col_space: Optional[Union[str, int, Dict[Name, Union[str, int]]]] = None, header: bool = True, index: bool = True, na_rep: str = "NaN", formatters: Optional[ - Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]] + Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]] ] = None, float_format: Optional[Callable[[float], str]] = None, sparsify: Optional[bool] = None, @@ -2010,13 +2010,13 @@ defaultdict(, {'col..., 'col...})] def to_latex( self, buf: Optional[IO[str]] = None, - columns: Optional[List[Union[Any, Tuple]]] = None, + columns: Optional[List[Name]] = None, col_space: Optional[int] = None, header: bool = True, index: bool = True, na_rep: str = "NaN", formatters: Optional[ - Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]] + Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]] ] = None, float_format: Optional[Callable[[float], str]] = None, sparsify: Optional[bool] = None, @@ -2545,7 +2545,7 @@ defaultdict(, {'col..., 'col...})] self_applied = DataFrame(self._internal.resolved_copy) # type: "DataFrame" - column_labels = None # type: Optional[List[Tuple]] + column_labels = None # type: Optional[List[Label]] if should_infer_schema: # Here we execute with the first 1000 to get the return type. # If the records were less than 1000, it uses pandas API directly for a shortcut. 
@@ -2802,7 +2802,7 @@ defaultdict(, {'col..., 'col...})] lambda psser: psser.pandas_on_spark.transform_batch(func, *args, **kwargs) ) - def pop(self, item: Union[Any, Tuple]) -> "DataFrame": + def pop(self, item: Name) -> "DataFrame": """ Return item and drop from frame. Raise KeyError if not found. @@ -2881,9 +2881,7 @@ defaultdict(, {'col..., 'col...})] return result # TODO: add axis parameter can work when '1' or 'columns' - def xs( - self, key: Union[Any, Tuple], axis: Axis = 0, level: Optional[int] = None - ) -> DataFrameOrSeries: + def xs(self, key: Name, axis: Axis = 0, level: Optional[int] = None) -> DataFrameOrSeries: """ Return cross-section from the DataFrame. @@ -3527,7 +3525,7 @@ defaultdict(, {'col..., 'col...})] def set_index( self, - keys: Union[Any, Tuple, List[Union[Any, Tuple]]], + keys: Union[Name, List[Name]], drop: bool = True, append: bool = False, inplace: bool = False, @@ -3596,7 +3594,7 @@ defaultdict(, {'col..., 'col...})] """ inplace = validate_bool_kwarg(inplace, "inplace") if is_name_like_tuple(keys): - key_list = [cast(Tuple, keys)] # type: List[Tuple] + key_list = [cast(Label, keys)] # type: List[Label] elif is_name_like_value(keys): key_list = [(keys,)] else: @@ -3642,7 +3640,7 @@ defaultdict(, {'col..., 'col...})] def reset_index( self, - level: Optional[Union[int, Any, Tuple, Sequence[Union[int, Any, Tuple]]]] = None, + level: Optional[Union[int, Name, Sequence[Union[int, Name]]]] = None, drop: bool = False, inplace: bool = False, col_level: int = 0, @@ -3793,7 +3791,7 @@ defaultdict(, {'col..., 'col...})] inplace = validate_bool_kwarg(inplace, "inplace") multi_index = self._internal.index_level > 1 - def rename(index: int) -> Tuple: + def rename(index: int) -> Label: if multi_index: return ("level_{}".format(index),) else: @@ -3818,9 +3816,9 @@ defaultdict(, {'col..., 'col...})] index_fields = [] else: if is_list_like(level): - level = list(cast(Sequence[Union[int, Any, Tuple]], level)) + level = list(cast(Sequence[Union[int, Name]], level)) if isinstance(level, int) or is_name_like_tuple(level): - level_list = [cast(Union[int, Tuple], level)] + level_list = [cast(Union[int, Label], level)] elif is_name_like_value(level): level_list = [(level,)] else: @@ -3841,7 +3839,7 @@ defaultdict(, {'col..., 'col...})] idx = int_level_list elif all(is_name_like_tuple(lev) for lev in level_list): idx = [] - for l in cast(List[Tuple], level_list): + for l in cast(List[Label], level_list): try: i = self._internal.index_names.index(l) idx.append(i) @@ -3985,7 +3983,7 @@ defaultdict(, {'col..., 'col...})] def insert( self, loc: int, - column: Union[Any, Tuple], + column: Name, value: Union[Scalar, "Series", Iterable], allow_duplicates: bool = False, ) -> None: @@ -4268,9 +4266,7 @@ defaultdict(, {'col..., 'col...})] ) return first_series(DataFrame(internal).transpose()) - def round( - self, decimals: Union[int, Dict[Union[Any, Tuple], int], "Series"] = 0 - ) -> "DataFrame": + def round(self, decimals: Union[int, Dict[Name, int], "Series"] = 0) -> "DataFrame": """ Round a DataFrame to a variable number of decimal places. 
@@ -4352,14 +4348,14 @@ defaultdict(, {'col..., 'col...})] def _mark_duplicates( self, - subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + subset: Optional[Union[Name, List[Name]]] = None, keep: str = "first", ) -> Tuple[SparkDataFrame, str]: if subset is None: subset_list = self._internal.column_labels else: if is_name_like_tuple(subset): - subset_list = [cast(Tuple, subset)] + subset_list = [cast(Label, subset)] elif is_name_like_value(subset): subset_list = [(subset,)] else: @@ -4395,7 +4391,7 @@ defaultdict(, {'col..., 'col...})] def duplicated( self, - subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + subset: Optional[Union[Name, List[Name]]] = None, keep: str = "first", ) -> "Series": """ @@ -5033,12 +5029,8 @@ defaultdict(, {'col..., 'col...})] def to_records( self, index: bool = True, - column_dtypes: Optional[ - Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]] - ] = None, - index_dtypes: Optional[ - Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]] - ] = None, + column_dtypes: Optional[Union[str, Dtype, Dict[Name, Union[str, Dtype]]]] = None, + index_dtypes: Optional[Union[str, Dtype, Dict[Name, Union[str, Dtype]]]] = None, ) -> np.recarray: """ Convert DataFrame to a NumPy record array. @@ -5152,7 +5144,7 @@ defaultdict(, {'col..., 'col...})] axis: Axis = 0, how: str = "any", thresh: Optional[int] = None, - subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + subset: Optional[Union[Name, List[Name]]] = None, inplace: bool = False, ) -> Optional["DataFrame"]: """ @@ -5256,7 +5248,7 @@ defaultdict(, {'col..., 'col...})] if subset is not None: if isinstance(subset, str): - labels = [(subset,)] # type: Optional[List[Tuple]] + labels = [(subset,)] # type: Optional[List[Label]] elif isinstance(subset, tuple): labels = [subset] else: @@ -5360,7 +5352,7 @@ defaultdict(, {'col..., 'col...})] # TODO: add 'limit' when value parameter exists def fillna( self, - value: Optional[Union[Any, Dict[Union[Any, Tuple], Any]]] = None, + value: Optional[Union[Any, Dict[Name, Any]]] = None, method: Optional[str] = None, axis: Optional[Axis] = None, inplace: bool = False, @@ -5835,10 +5827,10 @@ defaultdict(, {'col..., 'col...})] def pivot_table( self, - values: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, - index: Optional[List[Union[Any, Tuple]]] = None, - columns: Optional[Union[Any, Tuple]] = None, - aggfunc: Union[str, Dict[Union[Any, Tuple], str]] = "mean", + values: Optional[Union[Name, List[Name]]] = None, + index: Optional[List[Name]] = None, + columns: Optional[Name] = None, + aggfunc: Union[str, Dict[Name, str]] = "mean", fill_value: Optional[Any] = None, ) -> "DataFrame": """ @@ -6060,7 +6052,7 @@ defaultdict(, {'col..., 'col...})] for name in data_columns ] column_label_names = ( - [cast(Optional[Union[Any, Tuple]], None)] * column_labels_level(values) + [cast(Optional[Name], None)] * column_labels_level(values) ) + [columns] internal = InternalFrame( spark_frame=sdf, @@ -6074,9 +6066,7 @@ defaultdict(, {'col..., 'col...})] psdf = DataFrame(internal) # type: "DataFrame" else: column_labels = [tuple(list(values[0]) + [column]) for column in data_columns] - column_label_names = ( - [cast(Optional[Union[Any, Tuple]], None)] * len(values[0]) - ) + [columns] + column_label_names = ([cast(Optional[Name], None)] * len(values[0])) + [columns] internal = InternalFrame( spark_frame=sdf, index_spark_columns=[scol_for(sdf, col) for col in index_columns], @@ -6101,7 +6091,7 @@ defaultdict(, {'col..., 
'col...})] index_values = values[-1] else: index_values = values - index_map = OrderedDict() # type: Dict[str, Optional[Tuple]] + index_map = OrderedDict() # type: Dict[str, Optional[Label]] for i, index_value in enumerate(index_values): colname = SPARK_INDEX_NAME_FORMAT(i) sdf = sdf.withColumn(colname, SF.lit(index_value)) @@ -6131,9 +6121,9 @@ defaultdict(, {'col..., 'col...})] def pivot( self, - index: Optional[Union[Any, Tuple]] = None, - columns: Optional[Union[Any, Tuple]] = None, - values: Optional[Union[Any, Tuple]] = None, + index: Optional[Name] = None, + columns: Optional[Name] = None, + values: Optional[Name] = None, ) -> "DataFrame": """ Return reshaped DataFrame organized by given index / column values. @@ -6280,7 +6270,7 @@ defaultdict(, {'col..., 'col...})] return columns @columns.setter - def columns(self, columns: Union[pd.Index, List[Union[Any, Tuple]]]) -> None: + def columns(self, columns: Union[pd.Index, List[Name]]) -> None: if isinstance(columns, pd.MultiIndex): column_labels = columns.tolist() else: @@ -6523,7 +6513,7 @@ defaultdict(, {'col..., 'col...})] ) def droplevel( - self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]], axis: Axis = 0 + self, level: Union[int, Name, List[Union[int, Name]]], axis: Axis = 0 ) -> "DataFrame": """ Return DataFrame with requested index / column level(s) removed. @@ -6579,7 +6569,7 @@ defaultdict(, {'col..., 'col...})] if not isinstance(level, (tuple, list)): # huh? level = [level] - index_names = self.index.names + names = self.index.names nlevels = self._internal.index_level int_level = set() @@ -6599,9 +6589,9 @@ defaultdict(, {'col..., 'col...})] ) ) else: - if n not in index_names: + if n not in names: raise KeyError("Level {} not found".format(n)) - n = index_names.index(n) + n = names.index(n) int_level.add(n) if len(level) >= nlevels: @@ -6637,9 +6627,9 @@ defaultdict(, {'col..., 'col...})] def drop( self, - labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + labels: Optional[Union[Name, List[Name]]] = None, axis: Axis = 1, - columns: Union[Any, Tuple, List[Any], List[Tuple]] = None, + columns: Union[Name, List[Name]] = None, ) -> "DataFrame": """ Drop specified labels from columns. @@ -6773,7 +6763,7 @@ defaultdict(, {'col..., 'col...})] def sort_values( self, - by: Union[Any, List[Any], Tuple, List[Tuple]], + by: Union[Name, List[Name]], ascending: Union[bool, List[bool]] = True, inplace: bool = False, na_position: str = "last", @@ -6983,7 +6973,7 @@ defaultdict(, {'col..., 'col...})] return psdf def swaplevel( - self, i: Union[int, Any, Tuple] = -2, j: Union[int, Any, Tuple] = -1, axis: Axis = 0 + self, i: Union[int, Name] = -2, j: Union[int, Name] = -1, axis: Axis = 0 ) -> "DataFrame": """ Swap levels i and j in a MultiIndex on a particular axis. 
@@ -7142,9 +7132,7 @@ defaultdict(, {'col..., 'col...})] return self.copy() if i == j else self.transpose() - def _swaplevel_columns( - self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple] - ) -> InternalFrame: + def _swaplevel_columns(self, i: Union[int, Name], j: Union[int, Name]) -> InternalFrame: assert isinstance(self.columns, pd.MultiIndex) for index in (i, j): if not isinstance(index, int) and index not in self.columns.names: @@ -7174,9 +7162,7 @@ defaultdict(, {'col..., 'col...})] ) return internal - def _swaplevel_index( - self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple] - ) -> InternalFrame: + def _swaplevel_index(self, i: Union[int, Name], j: Union[int, Name]) -> InternalFrame: assert isinstance(self.index, ps.MultiIndex) for index in (i, j): if not isinstance(index, int) and index not in self.index.names: @@ -7208,7 +7194,7 @@ defaultdict(, {'col..., 'col...})] return internal # TODO: add keep = First - def nlargest(self, n: int, columns: "Any") -> "DataFrame": + def nlargest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": """ Return the first `n` rows ordered by `columns` in descending order. @@ -7282,7 +7268,7 @@ defaultdict(, {'col..., 'col...})] return self.sort_values(by=columns, ascending=False).head(n=n) # TODO: add keep = First - def nsmallest(self, n: int, columns: "Any") -> "DataFrame": + def nsmallest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame": """ Return the first `n` rows ordered by `columns` in ascending order. @@ -7447,9 +7433,9 @@ defaultdict(, {'col..., 'col...})] self, right: "DataFrame", how: str = "inner", - on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None, - left_on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None, - right_on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None, + on: Optional[Union[Name, List[Name]]] = None, + left_on: Optional[Union[Name, List[Name]]] = None, + right_on: Optional[Union[Name, List[Name]]] = None, left_index: bool = False, right_index: bool = False, suffixes: Tuple[str, str] = ("_x", "_y"), @@ -7571,7 +7557,7 @@ defaultdict(, {'col..., 'col...})] instead of NaN. """ - def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tuple]: + def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]: if os is None: return [] elif is_name_like_tuple(os): @@ -7779,7 +7765,7 @@ defaultdict(, {'col..., 'col...})] def join( self, right: "DataFrame", - on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None, + on: Optional[Union[Name, List[Name]]] = None, how: str = "left", lsuffix: str = "", rsuffix: str = "", @@ -8176,9 +8162,7 @@ defaultdict(, {'col..., 'col...})] ) return DataFrame(self._internal.with_new_sdf(sdf)) - def astype( - self, dtype: Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]] - ) -> "DataFrame": + def astype(self, dtype: Union[str, Dtype, Dict[Name, Union[str, Dtype]]]) -> "DataFrame": """ Cast a pandas-on-Spark object to a specified dtype ``dtype``. 
@@ -8234,7 +8218,7 @@ defaultdict(, {'col..., 'col...})] """ applied = [] if is_dict_like(dtype): - dtype_dict = cast(Dict[Union[Any, Tuple], Union[str, Dtype]], dtype) + dtype_dict = cast(Dict[Name, Union[str, Dtype]], dtype) for col_name in dtype_dict.keys(): if col_name not in self.columns: raise KeyError( @@ -8527,7 +8511,7 @@ defaultdict(, {'col..., 'col...})] def drop_duplicates( self, - subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + subset: Optional[Union[Name, List[Name]]] = None, keep: str = "first", inplace: bool = False, ) -> Optional["DataFrame"]: @@ -8997,8 +8981,8 @@ defaultdict(, {'col..., 'col...})] def melt( self, - id_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, - value_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + id_vars: Optional[Union[Name, List[Name]]] = None, + value_vars: Optional[Union[Name, List[Name]]] = None, var_name: Optional[Union[str, List[str]]] = None, value_name: str = "value", ) -> "DataFrame": @@ -10168,7 +10152,7 @@ defaultdict(, {'col..., 'col...})] if level < 0 or level >= psdf._internal.column_labels_level: raise ValueError("level should be an integer between [0, column_labels_level)") - def gen_new_column_labels_entry(column_labels_entry: Tuple) -> Tuple: + def gen_new_column_labels_entry(column_labels_entry: Label) -> Label: if level is None: # rename all level columns return tuple(map(columns_mapper_fn, column_labels_entry)) @@ -10193,15 +10177,9 @@ defaultdict(, {'col..., 'col...})] def rename_axis( self, - mapper: Union[ - Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any] - ] = None, - index: Union[ - Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any] - ] = None, - columns: Union[ - Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any] - ] = None, + mapper: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None, + index: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None, + columns: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None, axis: Optional[Axis] = 0, inplace: Optional[bool] = False, ) -> Optional["DataFrame"]: @@ -10311,20 +10289,18 @@ defaultdict(, {'col..., 'col...})] """ def gen_names( - v: Union[ - Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any] - ], - curnames: List[Union[Any, Tuple]], - ) -> List[Tuple]: + v: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]], + curnames: List[Name], + ) -> List[Label]: if is_scalar(v): - newnames = [cast(Any, v)] # type: List[Union[Any, Tuple]] + newnames = [cast(Any, v)] # type: List[Name] elif is_list_like(v) and not is_dict_like(v): newnames = list(cast(Sequence[Any], v)) elif is_dict_like(v): - v_dict = cast(Dict[Union[Any, Tuple], Any], v) + v_dict = cast(Dict[Name, Any], v) newnames = [v_dict[name] if name in v_dict else name for name in curnames] elif callable(v): - v_callable = cast(Callable[[Union[Any, Tuple]], Any], v) + v_callable = cast(Callable[[Name], Any], v) newnames = [v_callable(name) for name in curnames] else: raise ValueError( @@ -11184,7 +11160,7 @@ defaultdict(, {'col..., 'col...})] # Returns a frame return result - def explode(self, column: Union[Any, Tuple]) -> "DataFrame": + def explode(self, column: Name) -> "DataFrame": """ Transform each element of a list-like to a row, replicating index values. 
@@ -11280,7 +11256,7 @@ defaultdict(, {'col..., 'col...})] if axis == 0: - def get_spark_column(psdf: DataFrame, label: Tuple) -> Column: + def get_spark_column(psdf: DataFrame, label: Label) -> Column: scol = psdf._internal.spark_column_for(label) col_type = psdf._internal.spark_type_for(label) @@ -11289,7 +11265,7 @@ defaultdict(, {'col..., 'col...})] return scol - new_column_labels = [] # type: List[Tuple] + new_column_labels = [] # type: List[Label] for label in self._internal.column_labels: # Filtering out only columns of numeric and boolean type column. dtype = self._psser_for(label).spark.data_type @@ -11603,10 +11579,10 @@ defaultdict(, {'col..., 'col...})] @staticmethod def from_dict( - data: Dict[Union[Any, Tuple], Sequence[Any]], + data: Dict[Name, Sequence[Any]], orient: str = "columns", dtype: Union[str, Dtype] = None, - columns: Optional[List[Union[Any, Tuple]]] = None, + columns: Optional[List[Name]] = None, ) -> "DataFrame": """ Construct DataFrame from dict of array-like or dicts. @@ -11673,7 +11649,7 @@ defaultdict(, {'col..., 'col...})] # Override the `groupby` to specify the actual return type annotation. def groupby( self, - by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]], + by: Union[Name, "Series", List[Union[Name, "Series"]]], axis: Axis = 0, as_index: bool = True, dropna: bool = True, @@ -11685,7 +11661,7 @@ defaultdict(, {'col..., 'col...})] groupby.__doc__ = Frame.groupby.__doc__ def _build_groupby( - self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool + self, by: List[Union["Series", Label]], as_index: bool, dropna: bool ) -> "DataFrameGroupBy": from pyspark.pandas.groupby import DataFrameGroupBy @@ -11780,8 +11756,8 @@ defaultdict(, {'col..., 'col...})] value = DataFrame._index_normalized_frame(level, value) def assign_columns( - psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple] - ) -> Iterator[Tuple["Series", Tuple]]: + psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label] + ) -> Iterator[Tuple["Series", Label]]: assert len(key) == len(that_column_labels) # Note that here intentionally uses `zip_longest` that combine # that_columns. @@ -11821,9 +11797,7 @@ defaultdict(, {'col..., 'col...})] self._update_internal_frame(psdf._internal) @staticmethod - def _index_normalized_label( - level: int, labels: Union[Any, Tuple, Sequence[Union[Any, Tuple]]] - ) -> List[Tuple]: + def _index_normalized_label(level: int, labels: Union[Name, Sequence[Name]]) -> List[Label]: """ Returns a label that is normalized against the current column index level. 
For example, the key "abc" can be ("abc", "", "") if the current Frame has @@ -11910,7 +11884,7 @@ defaultdict(, {'col..., 'col...})] ] return list(super().__dir__()) + fields - def __iter__(self) -> Iterator[Union[Any, Tuple]]: + def __iter__(self) -> Iterator[Name]: return iter(self.columns) # NDArray Compat @@ -11930,8 +11904,8 @@ defaultdict(, {'col..., 'col...})] # Different DataFrames def apply_op( - psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple] - ) -> Iterator[Tuple["Series", Tuple]]: + psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label] + ) -> Iterator[Tuple["Series", Label]]: for this_label, that_label in zip(this_column_labels, that_column_labels): yield ( ufunc( diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 790abef3c0..c60097e952 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -53,7 +53,15 @@ from pyspark.sql.types import ( ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. -from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, FrameLike, Scalar +from pyspark.pandas._typing import ( + Axis, + DataFrameOrSeries, + Dtype, + FrameLike, + Label, + Name, + Scalar, +) from pyspark.pandas.indexing import AtIndexer, iAtIndexer, iLocIndexer, LocIndexer from pyspark.pandas.internal import InternalFrame from pyspark.pandas.spark import functions as SF @@ -636,7 +644,7 @@ class Frame(object, metaclass=ABCMeta): path: Optional[str] = None, sep: str = ",", na_rep: str = "", - columns: Optional[List[Union[Any, Tuple]]] = None, + columns: Optional[List[Name]] = None, header: bool = True, quotechar: str = '"', date_format: Optional[str] = None, @@ -811,9 +819,11 @@ class Frame(object, metaclass=ABCMeta): column_labels = psdf._internal.column_labels else: column_labels = [] - for label in columns: - if not is_name_like_tuple(label): - label = (label,) + for col in columns: + if is_name_like_tuple(col): + label = cast(Label, col) + else: + label = cast(Label, (col,)) if label not in psdf._internal.column_labels: raise KeyError(name_like_string(label)) column_labels.append(label) @@ -2119,7 +2129,7 @@ class Frame(object, metaclass=ABCMeta): # should be updated when it's supported. 
def groupby( self: FrameLike, - by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]], + by: Union[Name, "Series", List[Union[Name, "Series"]]], axis: Axis = 0, as_index: bool = True, dropna: bool = True, @@ -2206,15 +2216,15 @@ class Frame(object, metaclass=ABCMeta): if isinstance(by, ps.DataFrame): raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by).__name__)) elif isinstance(by, ps.Series): - new_by = [by] # type: List[Union[Tuple, ps.Series]] + new_by = [by] # type: List[Union[Label, ps.Series]] elif is_name_like_tuple(by): if isinstance(self, ps.Series): raise KeyError(by) - new_by = [cast(Tuple, by)] + new_by = [cast(Label, by)] elif is_name_like_value(by): if isinstance(self, ps.Series): raise KeyError(by) - new_by = [(by,)] + new_by = [cast(Label, (by,))] elif is_list_like(by): new_by = [] for key in by: @@ -2227,11 +2237,11 @@ class Frame(object, metaclass=ABCMeta): elif is_name_like_tuple(key): if isinstance(self, ps.Series): raise KeyError(key) - new_by.append(key) + new_by.append(cast(Label, key)) elif is_name_like_value(key): if isinstance(self, ps.Series): raise KeyError(key) - new_by.append((key,)) + new_by.append(cast(Label, (key,))) else: raise ValueError( "Grouper for '{}' not 1-dimensional".format(type(key).__name__) @@ -2248,7 +2258,7 @@ class Frame(object, metaclass=ABCMeta): @abstractmethod def _build_groupby( - self: FrameLike, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool + self: FrameLike, by: List[Union["Series", Label]], as_index: bool, dropna: bool ) -> "GroupBy[FrameLike]": pass diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 1d0f85d5c2..89fd4f2048 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -58,7 +58,7 @@ from pyspark.sql.types import ( # noqa: F401 ) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. -from pyspark.pandas._typing import Axis, FrameLike +from pyspark.pandas._typing import Axis, FrameLike, Label, Name from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType from pyspark.pandas.frame import DataFrame from pyspark.pandas.internal import ( @@ -110,7 +110,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): groupkeys: List[Series], as_index: bool, dropna: bool, - column_labels_to_exlcude: Set[Tuple], + column_labels_to_exlcude: Set[Label], agg_columns_selected: bool, agg_columns: List[Series], ): @@ -147,9 +147,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): # TODO: not all arguments are implemented comparing to pandas' for now. 
def aggregate( self, - func_or_funcs: Optional[ - Union[str, List[str], Dict[Union[Any, Tuple], Union[str, List[str]]]] - ] = None, + func_or_funcs: Optional[Union[str, List[str], Dict[Name, Union[str, List[str]]]]] = None, *args: Any, **kwargs: Any ) -> DataFrame: @@ -312,7 +310,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): @staticmethod def _spark_groupby( psdf: DataFrame, - func: Mapping[Union[Any, Tuple], Union[str, List[str]]], + func: Mapping[Name, Union[str, List[str]]], groupkeys: Sequence[Series] = (), ) -> InternalFrame: groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))] @@ -1405,11 +1403,11 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): @staticmethod def _prepare_group_map_apply( psdf: DataFrame, groupkeys: List[Series], agg_columns: List[Series] - ) -> Tuple[DataFrame, List[Tuple], List[str]]: + ) -> Tuple[DataFrame, List[Label], List[str]]: groupkey_labels = [ verify_temp_column_name(psdf, "__groupkey_{}__".format(i)) for i in range(len(groupkeys)) - ] # type: List[Tuple] + ] # type: List[Label] psdf = psdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns] groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels] return DataFrame(psdf._internal.resolved_copy), groupkey_labels, groupkey_names @@ -2377,7 +2375,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): return ExpandingGroupby(self, min_periods=min_periods) - def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) -> FrameLike: + def get_group(self, name: Union[Name, List[Name]]) -> FrameLike: """ Construct DataFrame from group with provided name. @@ -2594,8 +2592,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): @staticmethod def _resolve_grouping_from_diff_dataframes( - psdf: DataFrame, by: List[Union[Series, Tuple]] - ) -> Tuple[DataFrame, List[Series], Set[Tuple]]: + psdf: DataFrame, by: List[Union[Series, Label]] + ) -> Tuple[DataFrame, List[Series], Set[Label]]: column_labels_level = psdf._internal.column_labels_level column_labels = [] @@ -2636,8 +2634,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): ) def assign_columns( - psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple] - ) -> Iterator[Tuple[Series, Tuple]]: + psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label] + ) -> Iterator[Tuple[Series, Label]]: raise NotImplementedError( "Duplicated labels with groupby() and " "'compute.ops_on_diff_frames' option are not supported currently " @@ -2669,7 +2667,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): return psdf, new_by_series, tmp_column_labels @staticmethod - def _resolve_grouping(psdf: DataFrame, by: List[Union[Series, Tuple]]) -> List[Series]: + def _resolve_grouping(psdf: DataFrame, by: List[Union[Series, Label]]) -> List[Series]: new_by_series = [] for col_or_s in by: if isinstance(col_or_s, Series): @@ -2687,7 +2685,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta): class DataFrameGroupBy(GroupBy[DataFrame]): @staticmethod def _build( - psdf: DataFrame, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool + psdf: DataFrame, by: List[Union[Series, Label]], as_index: bool, dropna: bool ) -> "DataFrameGroupBy": if any(isinstance(col_or_s, Series) and not same_anchor(psdf, col_or_s) for col_or_s in by): ( @@ -2712,8 +2710,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]): by: List[Series], as_index: bool, dropna: bool, - column_labels_to_exlcude: Set[Tuple], - agg_columns: 
List[Tuple] = None, + column_labels_to_exlcude: Set[Label], + agg_columns: List[Label] = None, ): agg_columns_selected = agg_columns is not None if agg_columns_selected: @@ -2891,7 +2889,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]): class SeriesGroupBy(GroupBy[Series]): @staticmethod def _build( - psser: Series, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool + psser: Series, by: List[Union[Series, Label]], as_index: bool, dropna: bool ) -> "SeriesGroupBy": if any( isinstance(col_or_s, Series) and not same_anchor(psser, col_or_s) for col_or_s in by @@ -3255,8 +3253,8 @@ def is_multi_agg_with_relabel(**kwargs: Any) -> bool: def normalize_keyword_aggregation( - kwargs: Dict[str, Tuple[Union[Any, Tuple], str]], -) -> Tuple[Dict[Union[Any, Tuple], List[str]], List[str], List[Tuple]]: + kwargs: Dict[str, Tuple[Name, str]], +) -> Tuple[Dict[Name, List[str]], List[str], List[Tuple]]: """ Normalize user-provided kwargs. diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 0ca9724b9d..c06caf08c4 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -40,7 +40,7 @@ from pyspark.sql import functions as F, Column from pyspark.sql.types import FractionalType, IntegralType, TimestampType from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. -from pyspark.pandas._typing import Dtype, Scalar +from pyspark.pandas._typing import Dtype, Label, Name, Scalar from pyspark.pandas.config import get_option, option_context from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.frame import DataFrame @@ -125,7 +125,7 @@ class Index(IndexOpsMixin): data: Optional[Any] = None, dtype: Optional[Union[str, Dtype]] = None, copy: bool = False, - name: Optional[Union[Any, Tuple]] = None, + name: Optional[Name] = None, tupleize_cols: bool = True, **kwargs: Any ) -> "Index": @@ -215,7 +215,7 @@ class Index(IndexOpsMixin): ) @property - def _column_label(self) -> Optional[Tuple]: + def _column_label(self) -> Optional[Label]: return self._psdf._internal.index_names[0] def _with_new_scol(self, scol: Column, *, field: Optional[InternalField] = None) -> "Index": @@ -636,24 +636,24 @@ class Index(IndexOpsMixin): return not self.has_duplicates @property - def name(self) -> Union[Any, Tuple]: + def name(self) -> Name: """Return name of the Index.""" return self.names[0] @name.setter - def name(self, name: Union[Any, Tuple]) -> None: + def name(self, name: Name) -> None: self.names = [name] @property - def names(self) -> List[Union[Any, Tuple]]: + def names(self) -> List[Name]: """Return names of the Index.""" return [ name if name is None or len(name) > 1 else name[0] - for name in self._internal.index_names # type: ignore + for name in self._internal.index_names ] @names.setter - def names(self, names: List[Union[Any, Tuple]]) -> None: + def names(self, names: List[Name]) -> None: if not is_list_like(names): raise ValueError("Names must be a list-like") if self._internal.index_level != len(names): @@ -684,9 +684,7 @@ class Index(IndexOpsMixin): """ return self._internal.index_level - def rename( - self, name: Union[Any, Tuple, List[Union[Any, Tuple]]], inplace: bool = False - ) -> Optional["Index"]: + def rename(self, name: Union[Name, List[Name]], inplace: bool = False) -> Optional["Index"]: """ Alter Index or MultiIndex name. Able to set new names without level. Defaults to returning new index. 
@@ -749,7 +747,7 @@ class Index(IndexOpsMixin): else: return DataFrame(internal).index - def _verify_for_rename(self, name: Union[Any, Tuple]) -> List[Tuple]: + def _verify_for_rename(self, name: Name) -> List[Label]: if is_hashable(name): if is_name_like_tuple(name): return [name] @@ -830,7 +828,7 @@ class Index(IndexOpsMixin): ) return DataFrame(internal).index - def to_series(self, name: Optional[Union[Any, Tuple]] = None) -> Series: + def to_series(self, name: Optional[Name] = None) -> Series: """ Create a Series with both index and values equal to the index keys useful with map for returning an indexer based on an index. @@ -868,7 +866,7 @@ class Index(IndexOpsMixin): name = self.name column_labels = [ name if is_name_like_tuple(name) else (name,) - ] # type: List[Optional[Tuple]] + ] # type: List[Optional[Label]] internal = self._internal.copy( column_labels=column_labels, data_spark_columns=[scol], @@ -877,7 +875,7 @@ class Index(IndexOpsMixin): ) return first_series(DataFrame(internal)) - def to_frame(self, index: bool = True, name: Optional[Union[Any, Tuple]] = None) -> DataFrame: + def to_frame(self, index: bool = True, name: Optional[Name] = None) -> DataFrame: """ Create a DataFrame with a column containing the Index. @@ -939,7 +937,7 @@ class Index(IndexOpsMixin): return self._to_frame(index=index, names=[name]) - def _to_frame(self, index: bool, names: List[Tuple]) -> DataFrame: + def _to_frame(self, index: bool, names: List[Label]) -> DataFrame: if index: index_spark_columns = self._internal.index_spark_columns index_names = self._internal.index_names @@ -1115,7 +1113,7 @@ class Index(IndexOpsMixin): ) return DataFrame(internal).index - def unique(self, level: Optional[Union[int, Any, Tuple]] = None) -> "Index": + def unique(self, level: Optional[Union[int, Name]] = None) -> "Index": """ Return unique values in the index. @@ -1203,7 +1201,7 @@ class Index(IndexOpsMixin): ) return DataFrame(internal).index - def _validate_index_level(self, level: Union[int, Any, Tuple]) -> None: + def _validate_index_level(self, level: Union[int, Name]) -> None: """ Validate index level. For single-level Index getting level number is a no-op, but some @@ -1222,7 +1220,7 @@ class Index(IndexOpsMixin): "Requested level ({}) does not match index name ({})".format(level, self.name) ) - def get_level_values(self, level: Union[int, Any, Tuple]) -> "Index": + def get_level_values(self, level: Union[int, Name]) -> "Index": """ Return Index if a valid level is given. @@ -1238,9 +1236,7 @@ class Index(IndexOpsMixin): self._validate_index_level(level) return self - def copy( - self, name: Optional[Union[Any, Tuple]] = None, deep: Optional[bool] = None - ) -> "Index": + def copy(self, name: Optional[Name] = None, deep: Optional[bool] = None) -> "Index": """ Make a copy of this object. name sets those attributes on the new object. @@ -1279,7 +1275,7 @@ class Index(IndexOpsMixin): result.name = name return result - def droplevel(self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]) -> "Index": + def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Index": """ Return index with requested level(s) removed. 
If resulting index has only 1 level left, the result will be @@ -1317,9 +1313,9 @@ class Index(IndexOpsMixin): names = self.names nlevels = self.nlevels if not is_list_like(level): - levels = [cast(Union[int, Any, Tuple], level)] + levels = [cast(Union[int, Name], level)] else: - levels = cast(List[Union[int, Any, Tuple]], level) + levels = cast(List[Union[int, Name]], level) int_level = set() for n in levels: @@ -1375,7 +1371,7 @@ class Index(IndexOpsMixin): def symmetric_difference( self, other: "Index", - result_name: Optional[Union[Any, Tuple]] = None, + result_name: Optional[Name] = None, sort: Optional[bool] = None, ) -> "Index": """ @@ -1887,8 +1883,8 @@ class Index(IndexOpsMixin): def set_names( self, - names: Union[Any, Tuple, List[Union[Any, Tuple]]], - level: Optional[Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]] = None, + names: Union[Name, List[Name]], + level: Optional[Union[int, Name, List[Union[int, Name]]]] = None, inplace: bool = False, ) -> Optional["Index"]: """ diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 5f6b430ff8..0971f20d62 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -28,7 +28,7 @@ from pyspark.sql.types import DataType # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps # noqa: F401 -from pyspark.pandas._typing import Scalar +from pyspark.pandas._typing import Label, Name, Scalar from pyspark.pandas.exceptions import PandasNotImplementedError from pyspark.pandas.frame import DataFrame from pyspark.pandas.indexes.base import Index @@ -146,7 +146,7 @@ class MultiIndex(Index): ) @property - def _column_label(self) -> Optional[Tuple]: + def _column_label(self) -> Optional[Label]: return None def __abs__(self) -> "MultiIndex": @@ -169,7 +169,7 @@ class MultiIndex(Index): def from_tuples( tuples: List[Tuple], sortorder: Optional[int] = None, - names: Optional[List[Union[Any, Tuple]]] = None, + names: Optional[List[Name]] = None, ) -> "MultiIndex": """ Convert list of tuples to MultiIndex. @@ -210,7 +210,7 @@ class MultiIndex(Index): def from_arrays( arrays: List[List], sortorder: Optional[int] = None, - names: Optional[List[Union[Any, Tuple]]] = None, + names: Optional[List[Name]] = None, ) -> "MultiIndex": """ Convert arrays to MultiIndex. @@ -251,7 +251,7 @@ class MultiIndex(Index): def from_product( iterables: List[List], sortorder: Optional[int] = None, - names: Optional[List[Union[Any, Tuple]]] = None, + names: Optional[List[Name]] = None, ) -> "MultiIndex": """ Make a MultiIndex from the cartesian product of multiple iterables. @@ -297,7 +297,7 @@ class MultiIndex(Index): ) @staticmethod - def from_frame(df: DataFrame, names: Optional[List[Union[Any, Tuple]]] = None) -> "MultiIndex": + def from_frame(df: DataFrame, names: Optional[List[Name]] = None) -> "MultiIndex": """ Make a MultiIndex from a DataFrame. 
@@ -369,16 +369,14 @@ class MultiIndex(Index): return cast(MultiIndex, DataFrame(internal).index) @property - def name(self) -> Union[Any, Tuple]: + def name(self) -> Name: raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name") @name.setter - def name(self, name: Union[Any, Tuple]) -> None: + def name(self, name: Name) -> None: raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name") - def _verify_for_rename( # type: ignore[override] - self, name: List[Union[Any, Tuple]] - ) -> List[Tuple]: + def _verify_for_rename(self, name: List[Name]) -> List[Label]: # type: ignore[override] if is_list_like(name): if self._internal.index_level != len(name): raise ValueError( @@ -575,7 +573,7 @@ class MultiIndex(Index): return first_series(DataFrame(internal)) def to_frame( # type: ignore[override] - self, index: bool = True, name: Optional[List[Union[Any, Tuple]]] = None + self, index: bool = True, name: Optional[List[Name]] = None ) -> DataFrame: """ Create a DataFrame with the levels of the MultiIndex as columns. @@ -712,7 +710,7 @@ class MultiIndex(Index): def symmetric_difference( # type: ignore[override] self, other: Index, - result_name: Optional[List[Union[Any, Tuple]]] = None, + result_name: Optional[List[Name]] = None, sort: Optional[bool] = None, ) -> "MultiIndex": """ @@ -807,9 +805,7 @@ class MultiIndex(Index): return result # TODO: ADD error parameter - def drop( - self, codes: List[Any], level: Optional[Union[int, Any, Tuple]] = None - ) -> "MultiIndex": + def drop(self, codes: List[Any], level: Optional[Union[int, Name]] = None) -> "MultiIndex": """ Make new MultiIndex with passed list of labels deleted @@ -920,7 +916,7 @@ class MultiIndex(Index): return partial(property_or_func, self) raise AttributeError("'MultiIndex' object has no attribute '{}'".format(item)) - def _get_level_number(self, level: Union[int, Any, Tuple]) -> int: + def _get_level_number(self, level: Union[int, Name]) -> int: """ Return the level number if a valid level is given. """ @@ -948,7 +944,7 @@ class MultiIndex(Index): return level - def get_level_values(self, level: Union[int, Any, Tuple]) -> Index: + def get_level_values(self, level: Union[int, Name]) -> Index: """ Return vector of label values for requested level, equal to the length of the index. @@ -1046,7 +1042,7 @@ class MultiIndex(Index): index_name = [ (name,) for name in self._internal.index_spark_column_names - ] # type: List[Tuple] + ] # type: List[Label] sdf_before = self.to_frame(name=index_name)[:loc].to_spark() sdf_middle = Index([item]).to_frame(name=index_name).to_spark() sdf_after = self.to_frame(name=index_name)[loc:].to_spark() diff --git a/python/pyspark/pandas/indexes/numeric.py b/python/pyspark/pandas/indexes/numeric.py index 26795a8e38..ceb1b9eef7 100644 --- a/python/pyspark/pandas/indexes/numeric.py +++ b/python/pyspark/pandas/indexes/numeric.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Any, Optional, Tuple, Union, cast +from typing import Any, Optional, Union, cast import pandas as pd from pandas.api.types import is_hashable from pyspark import pandas as ps -from pyspark.pandas._typing import Dtype +from pyspark.pandas._typing import Dtype, Name from pyspark.pandas.indexes.base import Index from pyspark.pandas.series import Series @@ -89,7 +89,7 @@ class Int64Index(IntegerIndex): data: Optional[Any] = None, dtype: Optional[Union[str, Dtype]] = None, copy: bool = False, - name: Optional[Union[Any, Tuple]] = None, + name: Optional[Name] = None, ) -> "Int64Index": if not is_hashable(name): raise TypeError("Index.name must be a hashable type") @@ -151,7 +151,7 @@ class Float64Index(NumericIndex): data: Optional[Any] = None, dtype: Optional[Union[str, Dtype]] = None, copy: bool = False, - name: Optional[Union[Any, Tuple]] = None, + name: Optional[Name] = None, ) -> "Float64Index": if not is_hashable(name): raise TypeError("Index.name must be a hashable type") diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index 6e47b01a7b..f374a6f8dd 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -31,7 +31,7 @@ from pyspark.sql.utils import AnalysisException import numpy as np from pyspark import pandas as ps # noqa: F401 -from pyspark.pandas._typing import Scalar +from pyspark.pandas._typing import Label, Name, Scalar from pyspark.pandas.internal import ( InternalField, InternalFrame, @@ -274,13 +274,13 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta): return self._select_rows_else(rows_sel) def _select_cols( - self, cols_sel: Any, missing_keys: Optional[List[Tuple]] = None + self, cols_sel: Any, missing_keys: Optional[List[Name]] = None ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """ Dispatch the logic for select columns to more specific methods by `cols_sel` argument types. 
@@ -366,65 +366,65 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta): @abstractmethod def _select_cols_by_series( - self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] + self, cols_sel: "Series", missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """Select columns by `Series` type key.""" pass @abstractmethod def _select_cols_by_spark_column( - self, cols_sel: Column, missing_keys: Optional[List[Tuple]] + self, cols_sel: Column, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """Select columns by Spark `Column` type key.""" pass @abstractmethod def _select_cols_by_slice( - self, cols_sel: slice, missing_keys: Optional[List[Tuple]] + self, cols_sel: slice, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """Select columns by `slice` type key.""" pass @abstractmethod def _select_cols_by_iterable( - self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] + self, cols_sel: Iterable, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """Select columns by `Iterable` type key.""" pass @abstractmethod def _select_cols_else( - self, cols_sel: Any, missing_keys: Optional[List[Tuple]] + self, cols_sel: Any, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: """Select columns by other type key.""" pass @@ -706,7 +706,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta): return cond, limit, remaining_index = self._select_rows(rows_sel) - missing_keys = [] # type: Optional[List[Tuple]] + missing_keys = [] # type: Optional[List[Name]] _, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys) if cond is None: @@ -746,9 +746,11 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta): new_fields.append(new_field) column_labels = self._internal.column_labels.copy() - for label in missing_keys: - if not is_name_like_tuple(label): - label = (label,) + for missing in missing_keys: + if is_name_like_tuple(missing): + label = cast(Label, missing) + else: + label = cast(Label, (missing,)) if len(label) < self._internal.column_labels_level: label = tuple( list(label) + ([""] * (self._internal.column_labels_level - len(label))) @@ -1167,11 +1169,11 @@ class LocIndexer(LocIndexerLike): def _get_from_multiindex_column( self, - key: Optional[Tuple], - missing_keys: Optional[List[Tuple]], - labels: Optional[List[Tuple]] = None, + key: Optional[Label], + missing_keys: Optional[List[Name]], + labels: Optional[List[Tuple[Label, Label]]] = None, recursed: int = 0, - ) -> Tuple[List[Tuple], Optional[List[Column]], List[InternalField], bool, Optional[Tuple]]: + ) -> Tuple[List[Label], Optional[List[Column]], List[InternalField], bool, Optional[Name]]: """Select columns from multi-index columns.""" assert isinstance(key, tuple) if labels is None: @@ -1196,14 +1198,14 @@ class LocIndexer(LocIndexerLike): else: returns_series = all(lbl is None or len(lbl) == 0 for _, lbl in labels) if returns_series: - labels = set(label for label, _ in labels) 
# type: ignore - assert len(labels) == 1 - label = list(labels)[0] + label_set = set(label for label, _ in labels) + assert len(label_set) == 1 + label = list(label_set)[0] column_labels = [label] data_spark_columns = [self._internal.spark_column_for(label)] data_fields = [self._internal.field_for(label)] if label is None: - series_name = None + series_name = None # type: Name else: if recursed > 0: label = label[:-recursed] @@ -1219,13 +1221,13 @@ class LocIndexer(LocIndexerLike): return column_labels, data_spark_columns, data_fields, returns_series, series_name def _select_cols_by_series( - self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] + self, cols_sel: "Series", missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: column_labels = cols_sel._internal.column_labels data_spark_columns = cols_sel._internal.data_spark_columns @@ -1233,28 +1235,28 @@ class LocIndexer(LocIndexerLike): return column_labels, data_spark_columns, data_fields, True, None def _select_cols_by_spark_column( - self, cols_sel: Column, missing_keys: Optional[List[Tuple]] + self, cols_sel: Column, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: column_labels = [ (self._internal.spark_frame.select(cols_sel).columns[0],) - ] # type: List[Tuple] + ] # type: List[Label] data_spark_columns = [cols_sel] return column_labels, data_spark_columns, None, True, None def _select_cols_by_slice( - self, cols_sel: slice, missing_keys: Optional[List[Tuple]] + self, cols_sel: slice, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: start, stop = self._psdf_or_psser.columns.slice_locs( start=cols_sel.start, end=cols_sel.stop @@ -1265,13 +1267,13 @@ class LocIndexer(LocIndexerLike): return column_labels, data_spark_columns, data_fields, False, None def _select_cols_by_iterable( - self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] + self, cols_sel: Iterable, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: from pyspark.pandas.series import Series @@ -1356,13 +1358,13 @@ class LocIndexer(LocIndexerLike): return column_labels, data_spark_columns, data_fields, False, None def _select_cols_else( - self, cols_sel: Any, missing_keys: Optional[List[Tuple]] + self, cols_sel: Any, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: if not is_name_like_tuple(cols_sel): cols_sel = (cols_sel,) @@ -1552,7 +1554,7 @@ class iLocIndexer(LocIndexerLike): ) @lazy_property - def _sequence_col(self) -> Union[Any, Tuple]: + def _sequence_col(self) -> str: # Use resolved_copy to fix the natural order. 
internal = super()._internal.resolved_copy return verify_temp_column_name(internal.spark_frame, "__distributed_sequence_column__") @@ -1692,13 +1694,13 @@ class iLocIndexer(LocIndexerLike): ) def _select_cols_by_series( - self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] + self, cols_sel: "Series", missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: raise ValueError( "Location based indexing can only have [integer, integer slice, " @@ -1706,13 +1708,13 @@ class iLocIndexer(LocIndexerLike): ) def _select_cols_by_spark_column( - self, cols_sel: Column, missing_keys: Optional[List[Tuple]] + self, cols_sel: Column, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: raise ValueError( "Location based indexing can only have [integer, integer slice, " @@ -1720,13 +1722,13 @@ class iLocIndexer(LocIndexerLike): ) def _select_cols_by_slice( - self, cols_sel: slice, missing_keys: Optional[List[Tuple]] + self, cols_sel: slice, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: if all( s is None or isinstance(s, int) for s in (cols_sel.start, cols_sel.stop, cols_sel.step) @@ -1750,13 +1752,13 @@ class iLocIndexer(LocIndexerLike): ) def _select_cols_by_iterable( - self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] + self, cols_sel: Iterable, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: if all(isinstance(s, bool) for s in cols_sel): cols_sel = [i for i, s in enumerate(cols_sel) if s] @@ -1769,13 +1771,13 @@ class iLocIndexer(LocIndexerLike): raise TypeError("cannot perform reduce with flexible type") def _select_cols_else( - self, cols_sel: Any, missing_keys: Optional[List[Tuple]] + self, cols_sel: Any, missing_keys: Optional[List[Name]] ) -> Tuple[ - List[Tuple], + List[Label], Optional[List[Column]], Optional[List[InternalField]], bool, - Optional[Tuple], + Optional[Name], ]: if isinstance(cols_sel, int): if cols_sel > len(self._internal.column_labels): diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index a4ffae36d1..2c2fa97809 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -41,6 +41,7 @@ from pyspark.sql.types import ( # noqa: F401 # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps +from pyspark.pandas._typing import Label if TYPE_CHECKING: # This is required in old Python 3.5 to prevent circular reference. 
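A minimal sketch, not part of the patch, of the convention these hunks lean on: a Name is whatever a caller passes for a column or index name (a scalar or a tuple), while a Label is the normalized tuple the internals key on. The as_label helper below is hypothetical, written only to illustrate the promotion-and-padding that the missing-keys handling earlier in the indexing hunks performs.

from pyspark.pandas._typing import Label, Name


def as_label(name: Name, levels: int = 1) -> Label:
    # Promote a scalar name to a 1-tuple, then pad with "" up to the given
    # column-label level, mirroring the missing-keys normalization above.
    label = name if isinstance(name, tuple) else (name,)
    return tuple(list(label) + [""] * (levels - len(label)))


assert as_label("a") == ("a",)
assert as_label(("a", "b")) == ("a", "b")
assert as_label("a", levels=2) == ("a", "")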
@@ -529,12 +530,12 @@ class InternalFrame(object): self, spark_frame: SparkDataFrame, index_spark_columns: Optional[List[Column]], - index_names: Optional[List[Optional[Tuple]]] = None, + index_names: Optional[List[Optional[Label]]] = None, index_fields: Optional[List[InternalField]] = None, - column_labels: Optional[List[Tuple]] = None, + column_labels: Optional[List[Label]] = None, data_spark_columns: Optional[List[Column]] = None, data_fields: Optional[List[InternalField]] = None, - column_label_names: Optional[List[Optional[Tuple]]] = None, + column_label_names: Optional[List[Optional[Label]]] = None, ): """ Create a new internal immutable DataFrame to manage Spark DataFrame, column fields and @@ -783,13 +784,13 @@ class InternalFrame(object): is_name_like_tuple(index_name, check_type=True) for index_name in index_names ), index_names - self._index_names = index_names # type: List[Optional[Tuple]] + self._index_names = index_names # type: List[Optional[Label]] # column_labels if column_labels is None: self._column_labels = [ (col,) for col in spark_frame.select(self._data_spark_columns).columns - ] # type: List[Tuple] + ] # type: List[Label] else: assert len(column_labels) == len(self._data_spark_columns), ( len(column_labels), @@ -810,7 +811,7 @@ class InternalFrame(object): if column_label_names is None: self._column_label_names = [None] * column_labels_level( self._column_labels - ) # type: List[Optional[Tuple]] + ) # type: List[Optional[Label]] else: if len(self._column_labels) > 0: assert len(column_label_names) == column_labels_level(self._column_labels), ( @@ -1027,7 +1028,7 @@ class InternalFrame(object): False, ) - def spark_column_for(self, label: Tuple) -> Column: + def spark_column_for(self, label: Label) -> Column: """Return Spark Column for the given column label.""" column_labels_to_scol = dict(zip(self.column_labels, self.data_spark_columns)) if label in column_labels_to_scol: @@ -1035,28 +1036,28 @@ class InternalFrame(object): else: raise KeyError(name_like_string(label)) - def spark_column_name_for(self, label_or_scol: Union[Tuple, Column]) -> str: + def spark_column_name_for(self, label_or_scol: Union[Label, Column]) -> str: """Return the actual Spark column name for the given column label.""" if isinstance(label_or_scol, Column): return self.spark_frame.select(label_or_scol).columns[0] else: return self.field_for(label_or_scol).name - def spark_type_for(self, label_or_scol: Union[Tuple, Column]) -> DataType: + def spark_type_for(self, label_or_scol: Union[Label, Column]) -> DataType: """Return DataType for the given column label.""" if isinstance(label_or_scol, Column): return self.spark_frame.select(label_or_scol).schema[0].dataType else: return self.field_for(label_or_scol).spark_type - def spark_column_nullable_for(self, label_or_scol: Union[Tuple, Column]) -> bool: + def spark_column_nullable_for(self, label_or_scol: Union[Label, Column]) -> bool: """Return nullability for the given column label.""" if isinstance(label_or_scol, Column): return self.spark_frame.select(label_or_scol).schema[0].nullable else: return self.field_for(label_or_scol).nullable - def field_for(self, label: Tuple) -> InternalField: + def field_for(self, label: Label) -> InternalField: """Return InternalField for the given column label.""" column_labels_to_fields = dict(zip(self.column_labels, self.data_fields)) if label in column_labels_to_fields: @@ -1105,7 +1106,7 @@ class InternalFrame(object): ] @property - def index_names(self) -> List[Optional[Tuple]]: + def index_names(self) -> 
List[Optional[Label]]: """Return the managed index names.""" return self._index_names @@ -1115,7 +1116,7 @@ class InternalFrame(object): return len(self._index_names) @property - def column_labels(self) -> List[Tuple]: + def column_labels(self) -> List[Label]: """Return the managed column index.""" return self._column_labels @@ -1125,7 +1126,7 @@ class InternalFrame(object): return len(self._column_label_names) @property - def column_label_names(self) -> List[Optional[Tuple]]: + def column_label_names(self) -> List[Optional[Label]]: """Return names of the index levels.""" return self._column_label_names @@ -1197,10 +1198,10 @@ class InternalFrame(object): pdf: pd.DataFrame, *, index_columns: List[str], - index_names: List[Tuple], + index_names: List[Label], data_columns: List[str], - column_labels: List[Tuple], - column_label_names: List[Tuple], + column_labels: List[Label], + column_label_names: List[Label], fields: List[InternalField] = None, ) -> pd.DataFrame: """ @@ -1335,9 +1336,9 @@ class InternalFrame(object): self, scols_or_pssers: Sequence[Union[Column, "Series"]], *, - column_labels: Optional[List[Tuple]] = None, + column_labels: Optional[List[Label]] = None, data_fields: Optional[List[InternalField]] = None, - column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, + column_label_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue, keep_order: bool = True, ) -> "InternalFrame": """ @@ -1439,7 +1440,7 @@ class InternalFrame(object): def with_new_spark_column( self, - column_label: Tuple, + column_label: Label, scol: Column, *, field: Optional[InternalField] = None, @@ -1465,7 +1466,7 @@ class InternalFrame(object): data_spark_columns, data_fields=data_fields, keep_order=keep_order ) - def select_column(self, column_label: Tuple) -> "InternalFrame": + def select_column(self, column_label: Label) -> "InternalFrame": """ Copy the immutable InternalFrame with the specified column. @@ -1486,12 +1487,12 @@ class InternalFrame(object): *, spark_frame: Union[SparkDataFrame, _NoValueType] = _NoValue, index_spark_columns: Union[List[Column], _NoValueType] = _NoValue, - index_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, + index_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue, index_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue, - column_labels: Union[Optional[List[Tuple]], _NoValueType] = _NoValue, + column_labels: Union[Optional[List[Label]], _NoValueType] = _NoValue, data_spark_columns: Union[Optional[List[Column]], _NoValueType] = _NoValue, data_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue, - column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, + column_label_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue, ) -> "InternalFrame": """ Copy the immutable InternalFrame. 
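A simplified stand-in (not the real InternalFrame) showing why column labels are plain tuples: lookups such as spark_column_for and field_for above build a dict keyed by label, so a one-level column "a" is stored as ("a",) and a two-level column as ("b", "c").

labels = [("a",), ("b", "c")]
scols = ["scol_a", "scol_b_c"]  # stand-ins for pyspark Column objects

label_to_scol = dict(zip(labels, scols))


def spark_column_for(label):
    # Unknown labels raise KeyError, as in the real method above.
    if label in label_to_scol:
        return label_to_scol[label]
    raise KeyError(label)


assert spark_column_for(("a",)) == "scol_a"
assert spark_column_for(("b", "c")) == "scol_b_c"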
@@ -1530,12 +1531,12 @@ class InternalFrame(object): return InternalFrame( spark_frame=cast(SparkDataFrame, spark_frame), index_spark_columns=cast(List[Column], index_spark_columns), - index_names=cast(Optional[List[Optional[Tuple]]], index_names), + index_names=cast(Optional[List[Optional[Label]]], index_names), index_fields=cast(Optional[List[InternalField]], index_fields), - column_labels=cast(Optional[List[Tuple]], column_labels), + column_labels=cast(Optional[List[Label]], column_labels), data_spark_columns=cast(Optional[List[Column]], data_spark_columns), data_fields=cast(Optional[List[InternalField]], data_fields), - column_label_names=cast(Optional[List[Optional[Tuple]]], column_label_names), + column_label_names=cast(Optional[List[Optional[Label]]], column_label_names), ) @staticmethod @@ -1548,16 +1549,16 @@ class InternalFrame(object): index_names = [ name if name is None or isinstance(name, tuple) else (name,) for name in pdf.index.names - ] + ] # type: List[Optional[Label]] columns = pdf.columns if isinstance(columns, pd.MultiIndex): - column_labels = columns.tolist() + column_labels = columns.tolist() # type: List[Label] else: column_labels = [(col,) for col in columns] column_label_names = [ name if name is None or isinstance(name, tuple) else (name,) for name in columns.names - ] + ] # type: List[Optional[Label]] ( pdf, diff --git a/python/pyspark/pandas/ml.py b/python/pyspark/pandas/ml.py index 921d84f1a2..f0554ddefc 100644 --- a/python/pyspark/pandas/ml.py +++ b/python/pyspark/pandas/ml.py @@ -24,6 +24,7 @@ import pyspark from pyspark.ml.feature import VectorAssembler from pyspark.ml.stat import Correlation +from pyspark.pandas._typing import Label from pyspark.pandas.utils import column_labels_level if TYPE_CHECKING: @@ -62,7 +63,7 @@ def corr(psdf: "ps.DataFrame", method: str = "pearson") -> pd.DataFrame: return pd.DataFrame(arr, columns=idx, index=idx) -def to_numeric_df(psdf: "ps.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Tuple]]: +def to_numeric_df(psdf: "ps.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Label]]: """ Takes a dataframe and turns it into a dataframe containing a single numerical vector of doubles. This dataframe has a single field called '_1'. diff --git a/python/pyspark/pandas/mlflow.py b/python/pyspark/pandas/mlflow.py index 123171cec2..6178bac1f9 100644 --- a/python/pyspark/pandas/mlflow.py +++ b/python/pyspark/pandas/mlflow.py @@ -18,13 +18,14 @@ """ MLflow-related functions to load models and apply them to pandas-on-Spark dataframes. 
""" -from typing import List, Tuple, Union # noqa: F401 (SPARK-34943) +from typing import List, Union # noqa: F401 (SPARK-34943) from pyspark.sql.types import DataType import pandas as pd import numpy as np from typing import Any +from pyspark.pandas._typing import Label # noqa: F401 (SPARK-34943) from pyspark.pandas.utils import lazy_property, default_session from pyspark.pandas.frame import DataFrame from pyspark.pandas.series import Series, first_series @@ -99,7 +100,7 @@ class PythonModelWrapper(object): # return_col = self._model_udf(s) column_labels = [ (col,) for col in data._internal.spark_frame.select(return_col).columns - ] # type: List[Tuple] + ] # type: List[Label] internal = data._internal.copy( column_labels=column_labels, data_spark_columns=[return_col], data_fields=None ) diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index a545f6d598..70894861b1 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -65,7 +65,7 @@ from pyspark.sql.types import ( ) from pyspark import pandas as ps # noqa: F401 -from pyspark.pandas._typing import Axis, Dtype +from pyspark.pandas._typing import Axis, Dtype, Label, Name from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.utils import ( align_diff_frames, @@ -398,7 +398,7 @@ def read_csv( if col not in column_labels: raise KeyError(col) index_spark_column_names = [column_labels[col] for col in index_col] - index_names = [(col,) for col in index_col] # type: List[Tuple] + index_names = [(col,) for col in index_col] # type: List[Label] column_labels = OrderedDict( (label, col) for label, col in column_labels.items() if label not in index_col ) @@ -1818,7 +1818,7 @@ def get_dummies( prefix: Optional[Union[str, List[str], Dict[str, str]]] = None, prefix_sep: str = "_", dummy_na: bool = False, - columns: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + columns: Optional[Union[Name, List[Name]]] = None, sparse: bool = False, drop_first: bool = False, dtype: Optional[Union[str, Dtype]] = None, @@ -2443,8 +2443,8 @@ def concat( def melt( frame: DataFrame, - id_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, - value_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + id_vars: Optional[Union[Name, List[Name]]] = None, + value_vars: Optional[Union[Name, List[Name]]] = None, var_name: Optional[Union[str, List[str]]] = None, value_name: str = "value", ) -> DataFrame: @@ -2616,9 +2616,9 @@ def merge( obj: DataFrame, right: DataFrame, how: str = "inner", - on: Union[Any, List[Any], Tuple, List[Tuple]] = None, - left_on: Union[Any, List[Any], Tuple, List[Tuple]] = None, - right_on: Union[Any, List[Any], Tuple, List[Tuple]] = None, + on: Optional[Union[Name, List[Name]]] = None, + left_on: Optional[Union[Name, List[Name]]] = None, + right_on: Optional[Union[Name, List[Name]]] = None, left_index: bool = False, right_index: bool = False, suffixes: Tuple[str, str] = ("_x", "_y"), @@ -2919,7 +2919,7 @@ def read_orc( def _get_index_map( sdf: SparkDataFrame, index_col: Optional[Union[str, List[str]]] = None -) -> Tuple[Optional[List[Column]], Optional[List[Tuple]]]: +) -> Tuple[Optional[List[Column]], Optional[List[Label]]]: if index_col is not None: if isinstance(index_col, str): index_col = [index_col] @@ -2930,7 +2930,7 @@ def _get_index_map( index_spark_columns = [ scol_for(sdf, col) for col in index_col ] # type: Optional[List[Column]] - index_names = [(col,) for col in index_col] # type: Optional[List[Tuple]] + index_names = 
[(col,) for col in index_col] # type: Optional[List[Label]] else: index_spark_columns = None index_names = None diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 9ffa29fd13..f3a119882e 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -67,7 +67,7 @@ from pyspark.sql.types import ( from pyspark.sql.window import Window from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. -from pyspark.pandas._typing import Axis, Dtype, Scalar, T +from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T from pyspark.pandas.accessors import PandasOnSparkSeriesMethods from pyspark.pandas.categorical import CategoricalAccessor from pyspark.pandas.config import get_option @@ -410,7 +410,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): assert not fastpath self._anchor = data # type: DataFrame - self._col_label = index # type: Tuple + self._col_label = index # type: Label else: if isinstance(data, pd.Series): assert index is None @@ -441,7 +441,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): return self._psdf._internal.select_column(self._column_label) @property - def _column_label(self) -> Optional[Tuple]: + def _column_label(self) -> Optional[Label]: return self._col_label def _update_anchor(self, psdf: DataFrame) -> None: @@ -1042,7 +1042,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): return (len(self),) @property - def name(self) -> Union[Any, Tuple]: + def name(self) -> Name: """Return name of the Series.""" name = self._column_label if name is not None and len(name) == 1: @@ -1051,12 +1051,12 @@ class Series(Frame, IndexOpsMixin, Generic[T]): return name @name.setter - def name(self, name: Union[Any, Tuple]) -> None: + def name(self, name: Name) -> None: self.rename(name, inplace=True) # TODO: Functionality and documentation should be matched. Currently, changing index labels # taking dictionary and function to change index are not supported. - def rename(self, index: Optional[Union[Any, Tuple]] = None, **kwargs: Any) -> "Series": + def rename(self, index: Optional[Name] = None, **kwargs: Any) -> "Series": """ Alter Series name. @@ -1227,9 +1227,9 @@ class Series(Frame, IndexOpsMixin, Generic[T]): def reset_index( self, - level: Optional[Union[int, Any, Tuple, Sequence[Union[int, Any, Tuple]]]] = None, + level: Optional[Union[int, Name, Sequence[Union[int, Name]]]] = None, drop: bool = False, - name: Optional[Union[Any, Tuple]] = None, + name: Optional[Name] = None, inplace: bool = False, ) -> Optional[Union["Series", DataFrame]]: """ @@ -1324,7 +1324,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): else: return psdf - def to_frame(self, name: Optional[Union[Any, Tuple]] = None) -> DataFrame: + def to_frame(self, name: Optional[Name] = None) -> DataFrame: """ Convert Series to DataFrame. 
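A hedged illustration of what the Name return type on Series.name above captures: a one-level column label is unwrapped to its scalar, while a multi-level label stays a tuple. Plain pandas is used here only to keep the example self-contained; the observable behavior matches what the property's unwrapping in the hunk above produces.

import pandas as pd

s = pd.Series([1, 2], name="x")
assert s.name == "x"            # one-level label -> scalar Name

s.name = ("x", "y")             # a MultiIndex-style column name
assert s.name == ("x", "y")     # multi-level label -> tuple (a Label)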
@@ -1491,13 +1491,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]): def to_latex( self, buf: Optional[IO[str]] = None, - columns: Optional[List[Union[Any, Tuple]]] = None, + columns: Optional[List[Name]] = None, col_space: Optional[int] = None, header: bool = True, index: bool = True, na_rep: str = "NaN", formatters: Optional[ - Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]] + Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]] ] = None, float_format: Optional[Callable[[float], str]] = None, sparsify: Optional[bool] = None, @@ -2062,8 +2062,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]): def drop( self, - labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, - index: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + labels: Optional[Union[Name, List[Name]]] = None, + index: Optional[Union[Name, List[Name]]] = None, level: Optional[int] = None, ) -> "Series": """ @@ -2177,8 +2177,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]): def _drop( self, - labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, - index: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, + labels: Optional[Union[Name, List[Name]]] = None, + index: Optional[Union[Name, List[Name]]] = None, level: Optional[int] = None, ) -> DataFrame: if labels is not None: @@ -2193,7 +2193,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): raise ValueError("'level' should be less than the number of indexes") if is_name_like_tuple(index): # type: ignore - index_list = [cast(Tuple, index)] + index_list = [cast(Label, index)] elif is_name_like_value(index): index_list = [(index,)] elif all(is_name_like_value(idxes, allow_tuple=False) for idxes in index): @@ -2205,7 +2205,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): "that contain index names" ) else: - index_list = cast(List[Tuple], index) + index_list = cast(List[Label], index) drop_index_scols = [] for idxes in index_list: @@ -2602,7 +2602,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): return first_series(psdf) def swaplevel( - self, i: Union[int, Any, Tuple] = -2, j: Union[int, Any, Tuple] = -1, copy: bool = True + self, i: Union[int, Name] = -2, j: Union[int, Name] = -1, copy: bool = True ) -> "Series": """ Swap levels i and j in a MultiIndex. @@ -3903,7 +3903,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): else: return tuple(values) - def pop(self, item: Union[Any, Tuple]) -> Union["Series", Scalar]: + def pop(self, item: Name) -> Union["Series", Scalar]: """ Return item and drop from series. @@ -4671,7 +4671,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): """ return self.where(cast(Series, ~cond), other) - def xs(self, key: Union[Any, Tuple], level: Optional[int] = None) -> "Series": + def xs(self, key: Name, level: Optional[int] = None) -> "Series": """ Return cross-section from the Series. @@ -5278,7 +5278,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): """ return self.head(2)._to_internal_pandas().item() - def iteritems(self) -> Iterable[Tuple[Union[Any, Tuple], Any]]: + def iteritems(self) -> Iterable[Tuple[Name, Any]]: """ Lazily iterate over (index, value) tuples. 
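A rough sketch, with plain isinstance checks standing in for is_name_like_tuple and is_name_like_value, of the normalization Series._drop above applies to its index argument: a single Name becomes a one-element list of Labels, and a list of scalar names is wrapped element-wise. Error handling for mixed inputs is omitted here.

def normalize_drop_index(index):
    if isinstance(index, tuple):                      # already a Label
        return [index]
    if not isinstance(index, list):                   # a scalar Name
        return [(index,)]
    if all(not isinstance(i, tuple) for i in index):  # a list of scalar names
        return [(i,) for i in index]
    return list(index)                                # assume a list of Labels


assert normalize_drop_index("a") == [("a",)]
assert normalize_drop_index(("a", "b")) == [("a", "b")]
assert normalize_drop_index(["a", "b"]) == [("a",), ("b",)]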
@@ -5311,7 +5311,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): internal_index_columns = self._internal.index_spark_column_names internal_data_column = self._internal.data_spark_column_names[0] - def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]: + def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]: k = ( row[internal_index_columns[0]] if len(internal_index_columns) == 1 @@ -5325,11 +5325,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]): ): yield k, v - def items(self) -> Iterable[Tuple[Union[Any, Tuple], Any]]: + def items(self) -> Iterable[Tuple[Name, Any]]: """This is an alias of ``iteritems``.""" return self.iteritems() - def droplevel(self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]) -> "Series": + def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Series": """ Return Series with requested index level(s) removed. @@ -6213,7 +6213,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): # Override the `groupby` to specify the actual return type annotation. def groupby( self, - by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]], + by: Union[Name, "Series", List[Union[Name, "Series"]]], axis: Axis = 0, as_index: bool = True, dropna: bool = True, @@ -6225,7 +6225,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]): groupby.__doc__ = Frame.groupby.__doc__ def _build_groupby( - self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool + self, by: List[Union["Series", Label]], as_index: bool, dropna: bool ) -> "SeriesGroupBy": from pyspark.pandas.groupby import SeriesGroupBy diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index 18791c9ee6..dbacb1d494 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -46,7 +46,7 @@ from pandas.api.types import is_list_like # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps # noqa: F401 -from pyspark.pandas._typing import Axis, DataFrameOrSeries +from pyspark.pandas._typing import Axis, Label, Name, DataFrameOrSeries from pyspark.pandas.spark import functions as SF from pyspark.pandas.typedef.typehints import as_spark_type @@ -279,7 +279,7 @@ def combine_frames( level = max(this_internal.column_labels_level, that_internal.column_labels_level) - def fill_label(label: Optional[Tuple]) -> List: + def fill_label(label: Optional[Label]) -> List: if label is None: return ([""] * (level - 1)) + [None] else: @@ -289,7 +289,7 @@ def combine_frames( tuple(["this"] + fill_label(label)) for label in this_internal.column_labels ] + [tuple(["that"] + fill_label(label)) for label in that_internal.column_labels] column_label_names = ( - cast(List[Optional[Tuple]], [None]) * (1 + level - this_internal.column_labels_level) + cast(List[Optional[Label]], [None]) * (1 + level - this_internal.column_labels_level) ) + this_internal.column_label_names return DataFrame( InternalFrame( @@ -309,7 +309,7 @@ def combine_frames( def align_diff_frames( resolve_func: Callable[ - ["DataFrame", List[Tuple], List[Tuple]], Iterator[Tuple["Series", Tuple]] + ["DataFrame", List[Label], List[Label]], Iterator[Tuple["Series", Label]] ], this: "DataFrame", that: "DataFrame", @@ -385,11 +385,11 @@ def align_diff_frames( # 2. Apply the given function to transform the columns in a batch and keep the new columns. 
combined_column_labels = combined._internal.column_labels - that_columns_to_apply = [] # type: List[Tuple] - this_columns_to_apply = [] # type: List[Tuple] - additional_that_columns = [] # type: List[Tuple] + that_columns_to_apply = [] # type: List[Label] + this_columns_to_apply = [] # type: List[Label] + additional_that_columns = [] # type: List[Label] columns_to_keep = [] # type: List[Union[Series, Column]] - column_labels_to_keep = [] # type: List[Tuple] + column_labels_to_keep = [] # type: List[Label] for combined_label in combined_column_labels: for common_label in common_column_labels: @@ -424,7 +424,7 @@ def align_diff_frames( *resolve_func(combined, this_columns_to_apply, that_columns_to_apply) ) columns_applied = list(psser_set) # type: List[Union[Series, Column]] - column_labels_applied = list(column_labels_set) # type: List[Tuple] + column_labels_applied = list(column_labels_set) # type: List[Label] else: columns_applied = [] column_labels_applied = [] @@ -592,7 +592,7 @@ def scol_for(sdf: SparkDataFrame, column_name: str) -> Column: return sdf["`{}`".format(column_name)] -def column_labels_level(column_labels: List[Tuple]) -> int: +def column_labels_level(column_labels: List[Label]) -> int: """Return the level of the column index.""" if len(column_labels) == 0: return 1 @@ -602,7 +602,7 @@ def column_labels_level(column_labels: List[Tuple]) -> int: return list(levels)[0] -def name_like_string(name: Optional[Union[Any, Tuple]]) -> str: +def name_like_string(name: Optional[Name]) -> str: """ Return the name-like strings from str or tuple of str @@ -621,12 +621,12 @@ def name_like_string(name: Optional[Union[Any, Tuple]]) -> str: '(a, b, c)' """ if name is None: - name = ("__none__",) + label = ("__none__",) # type: Label elif is_list_like(name): - name = tuple([str(n) for n in name]) + label = tuple([str(n) for n in name]) else: - name = (str(name),) - return ("(%s)" % ", ".join(name)) if len(name) > 1 else name[0] + label = (str(name),) + return ("(%s)" % ", ".join(label)) if len(label) > 1 else label[0] def is_name_like_tuple(value: Any, allow_none: bool = True, check_type: bool = False) -> bool: @@ -760,15 +760,13 @@ def verify_temp_column_name(df: SparkDataFrame, column_name_or_label: str) -> st @overload -def verify_temp_column_name( - df: "DataFrame", column_name_or_label: Union[Any, Tuple] -) -> Union[Any, Tuple]: +def verify_temp_column_name(df: "DataFrame", column_name_or_label: Name) -> Label: ... def verify_temp_column_name( - df: Union["DataFrame", SparkDataFrame], column_name_or_label: Union[Any, Tuple] -) -> Union[Any, Tuple]: + df: Union["DataFrame", SparkDataFrame], column_name_or_label: Union[str, Name] +) -> Union[str, Label]: """ Verify that the given column name does not exist in the given pandas-on-Spark or Spark DataFrame.
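A self-contained restatement (with isinstance standing in for pandas' is_list_like) of the reworked name_like_string above; presumably the separate label local exists so the name parameter keeps its declared Name type while the normalized tuple is typed as a Label.

def name_like_string(name):
    if name is None:
        label = ("__none__",)
    elif isinstance(name, (list, tuple)):
        label = tuple(str(n) for n in name)
    else:
        label = (str(name),)
    return ("(%s)" % ", ".join(label)) if len(label) > 1 else label[0]


assert name_like_string(None) == "__none__"
assert name_like_string("a") == "a"
assert name_like_string(("a", "b", "c")) == "(a, b, c)"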