[SPARK-35944][PYTHON] Introduce Name and Label type aliases

### What changes were proposed in this pull request?

Introduce `Name` and `Label` type aliases to make explicit what is expected, instead of using a bare `Any` or `Union[Any, Tuple]`.

- `Label`: `Tuple[Any, ...]`
  The internal representation of name-like metadata, such as `index_names`, `column_labels`, and `column_label_names` in `InternalFrame` and similar internal structures; a label is always a tuple.
- `Name`: `Union[Any, Label]`
  The external, user-facing representation of a name, which can be a scalar value or a tuple; see the sketch below.
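As a rough sketch of how the two aliases relate (the `as_label` helper below is hypothetical, added only to illustrate the scalar-to-tuple normalization pattern that recurs throughout the diff; the aliases themselves are taken from this PR):

```python
from typing import Any, Tuple, Union

Label = Tuple[Any, ...]   # internal: a label is always a tuple
Name = Union[Any, Label]  # external: a user-facing name may be a scalar or a tuple

# Hypothetical helper mirroring the normalization used in the diff
# (e.g. `key_list = [(keys,)]` when a scalar name is given):
def as_label(name: Name) -> Label:
    return name if isinstance(name, tuple) else (name,)

assert as_label("id") == ("id",)           # scalar Name -> single-level Label
assert as_label(("a", "b")) == ("a", "b")  # tuple Name is already a Label
```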

### Why are the changes needed?

Currently `Any` or `Union[Any, Tuple]` is used for name-like values; dedicated type aliases state clearly which kind of name is expected.
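For instance, one of the signatures updated in this PR becomes self-documenting. This is a simplified, standalone rendering of `attach_id_column` from the diff below (function names suffixed and bodies elided here for illustration only):

```python
from typing import Any, Tuple, Union

Label = Tuple[Any, ...]
Name = Union[Any, Label]

# Before: nothing says `column` is a user-facing column name.
def attach_id_column_before(id_type: str, column: Union[Any, Tuple]) -> None: ...

# After: `Name` states that a scalar or tuple column name is expected.
def attach_id_column_after(id_type: str, column: Name) -> None: ...
```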

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #33159 from ueshin/issues/SPARK-35944/name_and_label.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Takuya UESHIN 2021-07-01 09:40:07 +09:00 committed by Hyukjin Kwon
parent 5ad12611ec
commit a98c8ae57d
16 changed files with 343 additions and 362 deletions


@@ -16,7 +16,7 @@
 #
 import datetime
 import decimal
-from typing import TypeVar, Union, TYPE_CHECKING
+from typing import Any, Tuple, TypeVar, Union, TYPE_CHECKING

 import numpy as np
 from pandas.api.extensions import ExtensionDtype
@@ -40,6 +40,10 @@ Scalar = Union[
     int, float, bool, str, bytes, decimal.Decimal, datetime.date, datetime.datetime, None
 ]

+# TODO: use the actual type parameters.
+Label = Tuple[Any, ...]
+Name = Union[Any, Label]
 Axis = Union[int, str]
 Dtype = Union[np.dtype, ExtensionDtype]


@@ -28,7 +28,7 @@ from pyspark.sql import functions as F
 from pyspark.sql.functions import pandas_udf
 from pyspark.sql.types import DataType, LongType, StructField, StructType
-from pyspark.pandas._typing import DataFrameOrSeries
+from pyspark.pandas._typing import DataFrameOrSeries, Name
 from pyspark.pandas.internal import (
     InternalField,
     InternalFrame,
@@ -56,7 +56,7 @@ class PandasOnSparkFrameMethods(object):
     def __init__(self, frame: "DataFrame"):
         self._psdf = frame

-    def attach_id_column(self, id_type: str, column: Union[Any, Tuple]) -> "DataFrame":
+    def attach_id_column(self, id_type: str, column: Name) -> "DataFrame":
        """
        Attach a column to be used as identifier of rows similar to the default index.


@@ -34,7 +34,7 @@ from pyspark.sql.types import (
 )
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
-from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, SeriesOrIndex
+from pyspark.pandas._typing import Axis, Dtype, IndexOpsLike, Label, SeriesOrIndex
 from pyspark.pandas.config import get_option, option_context
 from pyspark.pandas.internal import (
     InternalField,
@@ -297,7 +297,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
     @property
     @abstractmethod
-    def _column_label(self) -> Optional[Tuple]:
+    def _column_label(self) -> Optional[Label]:
         pass

     @property

@@ -83,7 +83,7 @@ from pyspark.sql.types import (  # noqa: F401 (SPARK-34943)
 from pyspark.sql.window import Window
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
-from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, Scalar, T
+from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, Label, Name, Scalar, T
 from pyspark.pandas.accessors import PandasOnSparkFrameMethods
 from pyspark.pandas.config import option_context, get_option
 from pyspark.pandas.spark import functions as SF
@@ -520,7 +520,7 @@ class DataFrame(Frame, Generic[T]):
         object.__setattr__(self, "_internal_frame", internal)

     @property
-    def _pssers(self) -> Dict[Tuple, "Series"]:
+    def _pssers(self) -> Dict[Label, "Series"]:
         """Return a dict of column label -> Series which anchors `self`."""
         from pyspark.pandas.series import Series
@@ -746,7 +746,7 @@ class DataFrame(Frame, Generic[T]):
         )
         return first_series(DataFrame(internal)).rename(pser.name)

-    def _psser_for(self, label: Tuple) -> "Series":
+    def _psser_for(self, label: Label) -> "Series":
        """
        Create Series with a proper column label.
@@ -806,9 +806,9 @@ class DataFrame(Frame, Generic[T]):
             # Different DataFrames
             def apply_op(
                 psdf: DataFrame,
-                this_column_labels: List[Tuple],
-                that_column_labels: List[Tuple],
-            ) -> Iterator[Tuple["Series", Tuple]]:
+                this_column_labels: List[Label],
+                that_column_labels: List[Label],
+            ) -> Iterator[Tuple["Series", Label]]:
                 for this_label, that_label in zip(this_column_labels, that_column_labels):
                     yield (
                         getattr(psdf._psser_for(this_label), op)(
@@ -1226,7 +1226,7 @@ class DataFrame(Frame, Generic[T]):
         return self._apply_series_op(lambda psser: psser.apply(func))

     # TODO: not all arguments are implemented comparing to pandas' for now.
-    def aggregate(self, func: Union[List[str], Dict[Any, List[str]]]) -> "DataFrame":
+    def aggregate(self, func: Union[List[str], Dict[Name, List[str]]]) -> "DataFrame":
         """Aggregate using one or more operations over the specified axis.

         Parameters
@@ -1388,7 +1388,7 @@ class DataFrame(Frame, Generic[T]):
         """
         return cast(DataFrame, ps.from_pandas(corr(self, method)))

-    def iteritems(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]:
+    def iteritems(self) -> Iterator[Tuple[Name, "Series"]]:
         """
         Iterator over (column name, Series) pairs.
@@ -1432,7 +1432,7 @@ class DataFrame(Frame, Generic[T]):
             for label in self._internal.column_labels
         )

-    def iterrows(self) -> Iterator[Tuple[Union[Any, Tuple], pd.Series]]:
+    def iterrows(self) -> Iterator[Tuple[Name, pd.Series]]:
         """
         Iterate over DataFrame rows as (index, Series) pairs.
@@ -1478,7 +1478,7 @@ class DataFrame(Frame, Generic[T]):
         internal_index_columns = self._internal.index_spark_column_names
         internal_data_columns = self._internal.data_spark_column_names

-        def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]:
+        def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]:
             k = (
                 row[internal_index_columns[0]]
                 if len(internal_index_columns) == 1
@@ -1567,7 +1567,7 @@ class DataFrame(Frame, Generic[T]):
         index_spark_column_names = self._internal.index_spark_column_names
         data_spark_column_names = self._internal.data_spark_column_names

-        def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]:
+        def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]:
             k = (
                 row[index_spark_column_names[0]]
                 if len(index_spark_column_names) == 1
@@ -1592,7 +1592,7 @@ class DataFrame(Frame, Generic[T]):
         ):
             yield tuple(([k] if index else []) + list(v))

-    def items(self) -> Iterator[Tuple[Union[Any, Tuple], "Series"]]:
+    def items(self) -> Iterator[Tuple[Name, "Series"]]:
         """This is an alias of ``iteritems``."""
         return self.iteritems()
@@ -1674,13 +1674,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def to_html(
         self,
         buf: Optional[IO[str]] = None,
-        columns: Optional[Sequence[Union[Any, Tuple]]] = None,
-        col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None,
+        columns: Optional[Sequence[Name]] = None,
+        col_space: Optional[Union[str, int, Dict[Name, Union[str, int]]]] = None,
         header: bool = True,
         index: bool = True,
         na_rep: str = "NaN",
         formatters: Optional[
-            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+            Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]]
         ] = None,
         float_format: Optional[Callable[[float], str]] = None,
         sparsify: Optional[bool] = None,
@@ -1796,13 +1796,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def to_string(
         self,
         buf: Optional[IO[str]] = None,
-        columns: Optional[Sequence[Union[Any, Tuple]]] = None,
-        col_space: Optional[Union[str, int, Dict[Union[Any, Tuple], Union[str, int]]]] = None,
+        columns: Optional[Sequence[Name]] = None,
+        col_space: Optional[Union[str, int, Dict[Name, Union[str, int]]]] = None,
         header: bool = True,
         index: bool = True,
         na_rep: str = "NaN",
         formatters: Optional[
-            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+            Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]]
         ] = None,
         float_format: Optional[Callable[[float], str]] = None,
         sparsify: Optional[bool] = None,
@@ -2010,13 +2010,13 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def to_latex(
         self,
         buf: Optional[IO[str]] = None,
-        columns: Optional[List[Union[Any, Tuple]]] = None,
+        columns: Optional[List[Name]] = None,
         col_space: Optional[int] = None,
         header: bool = True,
         index: bool = True,
         na_rep: str = "NaN",
         formatters: Optional[
-            Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]]
+            Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]]
         ] = None,
         float_format: Optional[Callable[[float], str]] = None,
         sparsify: Optional[bool] = None,
@@ -2545,7 +2545,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         self_applied = DataFrame(self._internal.resolved_copy)  # type: "DataFrame"

-        column_labels = None  # type: Optional[List[Tuple]]
+        column_labels = None  # type: Optional[List[Label]]
         if should_infer_schema:
             # Here we execute with the first 1000 to get the return type.
             # If the records were less than 1000, it uses pandas API directly for a shortcut.
@@ -2802,7 +2802,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             lambda psser: psser.pandas_on_spark.transform_batch(func, *args, **kwargs)
         )

-    def pop(self, item: Union[Any, Tuple]) -> "DataFrame":
+    def pop(self, item: Name) -> "DataFrame":
         """
         Return item and drop from frame. Raise KeyError if not found.
@@ -2881,9 +2881,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return result

     # TODO: add axis parameter can work when '1' or 'columns'
-    def xs(
-        self, key: Union[Any, Tuple], axis: Axis = 0, level: Optional[int] = None
-    ) -> DataFrameOrSeries:
+    def xs(self, key: Name, axis: Axis = 0, level: Optional[int] = None) -> DataFrameOrSeries:
         """
         Return cross-section from the DataFrame.
@@ -3527,7 +3525,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def set_index(
         self,
-        keys: Union[Any, Tuple, List[Union[Any, Tuple]]],
+        keys: Union[Name, List[Name]],
         drop: bool = True,
         append: bool = False,
         inplace: bool = False,
@@ -3596,7 +3594,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         inplace = validate_bool_kwarg(inplace, "inplace")
         if is_name_like_tuple(keys):
-            key_list = [cast(Tuple, keys)]  # type: List[Tuple]
+            key_list = [cast(Label, keys)]  # type: List[Label]
         elif is_name_like_value(keys):
             key_list = [(keys,)]
         else:
@@ -3642,7 +3640,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def reset_index(
         self,
-        level: Optional[Union[int, Any, Tuple, Sequence[Union[int, Any, Tuple]]]] = None,
+        level: Optional[Union[int, Name, Sequence[Union[int, Name]]]] = None,
         drop: bool = False,
         inplace: bool = False,
         col_level: int = 0,
@@ -3793,7 +3791,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         inplace = validate_bool_kwarg(inplace, "inplace")
         multi_index = self._internal.index_level > 1

-        def rename(index: int) -> Tuple:
+        def rename(index: int) -> Label:
             if multi_index:
                 return ("level_{}".format(index),)
             else:
@@ -3818,9 +3816,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             index_fields = []
         else:
             if is_list_like(level):
-                level = list(cast(Sequence[Union[int, Any, Tuple]], level))
+                level = list(cast(Sequence[Union[int, Name]], level))
             if isinstance(level, int) or is_name_like_tuple(level):
-                level_list = [cast(Union[int, Tuple], level)]
+                level_list = [cast(Union[int, Label], level)]
             elif is_name_like_value(level):
                 level_list = [(level,)]
             else:
@@ -3841,7 +3839,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 idx = int_level_list
             elif all(is_name_like_tuple(lev) for lev in level_list):
                 idx = []
-                for l in cast(List[Tuple], level_list):
+                for l in cast(List[Label], level_list):
                     try:
                         i = self._internal.index_names.index(l)
                         idx.append(i)
@@ -3985,7 +3983,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def insert(
         self,
         loc: int,
-        column: Union[Any, Tuple],
+        column: Name,
         value: Union[Scalar, "Series", Iterable],
         allow_duplicates: bool = False,
     ) -> None:
@@ -4268,9 +4266,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return first_series(DataFrame(internal).transpose())

-    def round(
-        self, decimals: Union[int, Dict[Union[Any, Tuple], int], "Series"] = 0
-    ) -> "DataFrame":
+    def round(self, decimals: Union[int, Dict[Name, int], "Series"] = 0) -> "DataFrame":
         """
         Round a DataFrame to a variable number of decimal places.
@@ -4352,14 +4348,14 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def _mark_duplicates(
         self,
-        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        subset: Optional[Union[Name, List[Name]]] = None,
         keep: str = "first",
     ) -> Tuple[SparkDataFrame, str]:
         if subset is None:
             subset_list = self._internal.column_labels
         else:
             if is_name_like_tuple(subset):
-                subset_list = [cast(Tuple, subset)]
+                subset_list = [cast(Label, subset)]
             elif is_name_like_value(subset):
                 subset_list = [(subset,)]
             else:
@@ -4395,7 +4391,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def duplicated(
         self,
-        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        subset: Optional[Union[Name, List[Name]]] = None,
         keep: str = "first",
     ) -> "Series":
         """
@@ -5033,12 +5029,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def to_records(
         self,
         index: bool = True,
-        column_dtypes: Optional[
-            Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
-        ] = None,
-        index_dtypes: Optional[
-            Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
-        ] = None,
+        column_dtypes: Optional[Union[str, Dtype, Dict[Name, Union[str, Dtype]]]] = None,
+        index_dtypes: Optional[Union[str, Dtype, Dict[Name, Union[str, Dtype]]]] = None,
     ) -> np.recarray:
         """
         Convert DataFrame to a NumPy record array.
@@ -5152,7 +5144,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         axis: Axis = 0,
         how: str = "any",
         thresh: Optional[int] = None,
-        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        subset: Optional[Union[Name, List[Name]]] = None,
         inplace: bool = False,
     ) -> Optional["DataFrame"]:
         """
@@ -5256,7 +5248,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if subset is not None:
             if isinstance(subset, str):
-                labels = [(subset,)]  # type: Optional[List[Tuple]]
+                labels = [(subset,)]  # type: Optional[List[Label]]
             elif isinstance(subset, tuple):
                 labels = [subset]
             else:
@@ -5360,7 +5352,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     # TODO: add 'limit' when value parameter exists
     def fillna(
         self,
-        value: Optional[Union[Any, Dict[Union[Any, Tuple], Any]]] = None,
+        value: Optional[Union[Any, Dict[Name, Any]]] = None,
         method: Optional[str] = None,
         axis: Optional[Axis] = None,
         inplace: bool = False,
@@ -5835,10 +5827,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def pivot_table(
         self,
-        values: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
-        index: Optional[List[Union[Any, Tuple]]] = None,
-        columns: Optional[Union[Any, Tuple]] = None,
-        aggfunc: Union[str, Dict[Union[Any, Tuple], str]] = "mean",
+        values: Optional[Union[Name, List[Name]]] = None,
+        index: Optional[List[Name]] = None,
+        columns: Optional[Name] = None,
+        aggfunc: Union[str, Dict[Name, str]] = "mean",
         fill_value: Optional[Any] = None,
     ) -> "DataFrame":
         """
@@ -6060,7 +6052,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     for name in data_columns
                 ]
                 column_label_names = (
-                    [cast(Optional[Union[Any, Tuple]], None)] * column_labels_level(values)
+                    [cast(Optional[Name], None)] * column_labels_level(values)
                 ) + [columns]
                 internal = InternalFrame(
                     spark_frame=sdf,
@@ -6074,9 +6066,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             psdf = DataFrame(internal)  # type: "DataFrame"
         else:
             column_labels = [tuple(list(values[0]) + [column]) for column in data_columns]
-            column_label_names = (
-                [cast(Optional[Union[Any, Tuple]], None)] * len(values[0])
-            ) + [columns]
+            column_label_names = ([cast(Optional[Name], None)] * len(values[0])) + [columns]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_spark_columns=[scol_for(sdf, col) for col in index_columns],
@@ -6101,7 +6091,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 index_values = values[-1]
             else:
                 index_values = values
-            index_map = OrderedDict()  # type: Dict[str, Optional[Tuple]]
+            index_map = OrderedDict()  # type: Dict[str, Optional[Label]]
             for i, index_value in enumerate(index_values):
                 colname = SPARK_INDEX_NAME_FORMAT(i)
                 sdf = sdf.withColumn(colname, SF.lit(index_value))
@@ -6131,9 +6121,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def pivot(
         self,
-        index: Optional[Union[Any, Tuple]] = None,
-        columns: Optional[Union[Any, Tuple]] = None,
-        values: Optional[Union[Any, Tuple]] = None,
+        index: Optional[Name] = None,
+        columns: Optional[Name] = None,
+        values: Optional[Name] = None,
     ) -> "DataFrame":
         """
         Return reshaped DataFrame organized by given index / column values.
@@ -6280,7 +6270,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return columns

     @columns.setter
-    def columns(self, columns: Union[pd.Index, List[Union[Any, Tuple]]]) -> None:
+    def columns(self, columns: Union[pd.Index, List[Name]]) -> None:
         if isinstance(columns, pd.MultiIndex):
             column_labels = columns.tolist()
         else:
@@ -6523,7 +6513,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )

     def droplevel(
-        self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]], axis: Axis = 0
+        self, level: Union[int, Name, List[Union[int, Name]]], axis: Axis = 0
     ) -> "DataFrame":
         """
         Return DataFrame with requested index / column level(s) removed.
@@ -6579,7 +6569,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if not isinstance(level, (tuple, list)):  # huh?
             level = [level]

-        index_names = self.index.names
+        names = self.index.names
         nlevels = self._internal.index_level

         int_level = set()
@@ -6599,9 +6589,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     )
                 )
             else:
-                if n not in index_names:
+                if n not in names:
                     raise KeyError("Level {} not found".format(n))
-                n = index_names.index(n)
+                n = names.index(n)
             int_level.add(n)

         if len(level) >= nlevels:
@@ -6637,9 +6627,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def drop(
         self,
-        labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        labels: Optional[Union[Name, List[Name]]] = None,
         axis: Axis = 1,
-        columns: Union[Any, Tuple, List[Any], List[Tuple]] = None,
+        columns: Union[Name, List[Name]] = None,
     ) -> "DataFrame":
         """
         Drop specified labels from columns.
@@ -6773,7 +6763,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def sort_values(
         self,
-        by: Union[Any, List[Any], Tuple, List[Tuple]],
+        by: Union[Name, List[Name]],
         ascending: Union[bool, List[bool]] = True,
         inplace: bool = False,
         na_position: str = "last",
@@ -6983,7 +6973,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return psdf

     def swaplevel(
-        self, i: Union[int, Any, Tuple] = -2, j: Union[int, Any, Tuple] = -1, axis: Axis = 0
+        self, i: Union[int, Name] = -2, j: Union[int, Name] = -1, axis: Axis = 0
     ) -> "DataFrame":
         """
         Swap levels i and j in a MultiIndex on a particular axis.
@@ -7142,9 +7132,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return self.copy() if i == j else self.transpose()

-    def _swaplevel_columns(
-        self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple]
-    ) -> InternalFrame:
+    def _swaplevel_columns(self, i: Union[int, Name], j: Union[int, Name]) -> InternalFrame:
         assert isinstance(self.columns, pd.MultiIndex)
         for index in (i, j):
             if not isinstance(index, int) and index not in self.columns.names:
@@ -7174,9 +7162,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return internal

-    def _swaplevel_index(
-        self, i: Union[int, Any, Tuple], j: Union[int, Any, Tuple]
-    ) -> InternalFrame:
+    def _swaplevel_index(self, i: Union[int, Name], j: Union[int, Name]) -> InternalFrame:
         assert isinstance(self.index, ps.MultiIndex)
         for index in (i, j):
             if not isinstance(index, int) and index not in self.index.names:
@@ -7208,7 +7194,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return internal

     # TODO: add keep = First
-    def nlargest(self, n: int, columns: "Any") -> "DataFrame":
+    def nlargest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame":
         """
         Return the first `n` rows ordered by `columns` in descending order.
@@ -7282,7 +7268,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return self.sort_values(by=columns, ascending=False).head(n=n)

     # TODO: add keep = First
-    def nsmallest(self, n: int, columns: "Any") -> "DataFrame":
+    def nsmallest(self, n: int, columns: Union[Name, List[Name]]) -> "DataFrame":
         """
         Return the first `n` rows ordered by `columns` in ascending order.
@@ -7447,9 +7433,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         self,
         right: "DataFrame",
         how: str = "inner",
-        on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None,
-        left_on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None,
-        right_on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None,
+        on: Optional[Union[Name, List[Name]]] = None,
+        left_on: Optional[Union[Name, List[Name]]] = None,
+        right_on: Optional[Union[Name, List[Name]]] = None,
         left_index: bool = False,
         right_index: bool = False,
         suffixes: Tuple[str, str] = ("_x", "_y"),
@@ -7571,7 +7557,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             instead of NaN.
         """

-        def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tuple]:
+        def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]:
             if os is None:
                 return []
             elif is_name_like_tuple(os):
@@ -7779,7 +7765,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def join(
         self,
         right: "DataFrame",
-        on: Optional[Union[Any, List[Any], Tuple, List[Tuple]]] = None,
+        on: Optional[Union[Name, List[Name]]] = None,
         how: str = "left",
         lsuffix: str = "",
         rsuffix: str = "",
@@ -8176,9 +8162,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         )
         return DataFrame(self._internal.with_new_sdf(sdf))

-    def astype(
-        self, dtype: Union[str, Dtype, Dict[Union[Any, Tuple], Union[str, Dtype]]]
-    ) -> "DataFrame":
+    def astype(self, dtype: Union[str, Dtype, Dict[Name, Union[str, Dtype]]]) -> "DataFrame":
         """
         Cast a pandas-on-Spark object to a specified dtype ``dtype``.
@@ -8234,7 +8218,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """
         applied = []
         if is_dict_like(dtype):
-            dtype_dict = cast(Dict[Union[Any, Tuple], Union[str, Dtype]], dtype)
+            dtype_dict = cast(Dict[Name, Union[str, Dtype]], dtype)
             for col_name in dtype_dict.keys():
                 if col_name not in self.columns:
                     raise KeyError(
@@ -8527,7 +8511,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def drop_duplicates(
         self,
-        subset: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        subset: Optional[Union[Name, List[Name]]] = None,
         keep: str = "first",
         inplace: bool = False,
     ) -> Optional["DataFrame"]:
@@ -8997,8 +8981,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def melt(
         self,
-        id_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
-        value_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None,
+        id_vars: Optional[Union[Name, List[Name]]] = None,
+        value_vars: Optional[Union[Name, List[Name]]] = None,
         var_name: Optional[Union[str, List[str]]] = None,
         value_name: str = "value",
     ) -> "DataFrame":
@@ -10168,7 +10152,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             if level < 0 or level >= psdf._internal.column_labels_level:
                 raise ValueError("level should be an integer between [0, column_labels_level)")

-        def gen_new_column_labels_entry(column_labels_entry: Tuple) -> Tuple:
+        def gen_new_column_labels_entry(column_labels_entry: Label) -> Label:
             if level is None:
                 # rename all level columns
                 return tuple(map(columns_mapper_fn, column_labels_entry))
@@ -10193,15 +10177,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     def rename_axis(
         self,
-        mapper: Union[
-            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
-        ] = None,
-        index: Union[
-            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
-        ] = None,
-        columns: Union[
-            Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
-        ] = None,
+        mapper: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None,
+        index: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None,
+        columns: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]] = None,
         axis: Optional[Axis] = 0,
         inplace: Optional[bool] = False,
     ) -> Optional["DataFrame"]:
@@ -10311,20 +10289,18 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         """

         def gen_names(
-            v: Union[
-                Any, Sequence[Any], Dict[Union[Any, Tuple], Any], Callable[[Union[Any, Tuple]], Any]
-            ],
-            curnames: List[Union[Any, Tuple]],
-        ) -> List[Tuple]:
+            v: Union[Any, Sequence[Any], Dict[Name, Any], Callable[[Name], Any]],
+            curnames: List[Name],
+        ) -> List[Label]:
             if is_scalar(v):
-                newnames = [cast(Any, v)]  # type: List[Union[Any, Tuple]]
+                newnames = [cast(Any, v)]  # type: List[Name]
             elif is_list_like(v) and not is_dict_like(v):
                 newnames = list(cast(Sequence[Any], v))
             elif is_dict_like(v):
-                v_dict = cast(Dict[Union[Any, Tuple], Any], v)
+                v_dict = cast(Dict[Name, Any], v)
                 newnames = [v_dict[name] if name in v_dict else name for name in curnames]
             elif callable(v):
-                v_callable = cast(Callable[[Union[Any, Tuple]], Any], v)
+                v_callable = cast(Callable[[Name], Any], v)
                 newnames = [v_callable(name) for name in curnames]
             else:
                 raise ValueError(
@@ -11184,7 +11160,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             # Returns a frame
             return result

-    def explode(self, column: Union[Any, Tuple]) -> "DataFrame":
+    def explode(self, column: Name) -> "DataFrame":
         """
         Transform each element of a list-like to a row, replicating index values.
@@ -11280,7 +11256,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if axis == 0:

-            def get_spark_column(psdf: DataFrame, label: Tuple) -> Column:
+            def get_spark_column(psdf: DataFrame, label: Label) -> Column:
                 scol = psdf._internal.spark_column_for(label)
                 col_type = psdf._internal.spark_type_for(label)
@@ -11289,7 +11265,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 return scol

-            new_column_labels = []  # type: List[Tuple]
+            new_column_labels = []  # type: List[Label]
             for label in self._internal.column_labels:
                 # Filtering out only columns of numeric and boolean type column.
                 dtype = self._psser_for(label).spark.data_type
@@ -11603,10 +11579,10 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     @staticmethod
     def from_dict(
-        data: Dict[Union[Any, Tuple], Sequence[Any]],
+        data: Dict[Name, Sequence[Any]],
         orient: str = "columns",
         dtype: Union[str, Dtype] = None,
-        columns: Optional[List[Union[Any, Tuple]]] = None,
+        columns: Optional[List[Name]] = None,
     ) -> "DataFrame":
         """
         Construct DataFrame from dict of array-like or dicts.
@@ -11673,7 +11649,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     # Override the `groupby` to specify the actual return type annotation.
     def groupby(
         self,
-        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        by: Union[Name, "Series", List[Union[Name, "Series"]]],
         axis: Axis = 0,
         as_index: bool = True,
         dropna: bool = True,
@@ -11685,7 +11661,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
     groupby.__doc__ = Frame.groupby.__doc__

     def _build_groupby(
-        self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
+        self, by: List[Union["Series", Label]], as_index: bool, dropna: bool
     ) -> "DataFrameGroupBy":
         from pyspark.pandas.groupby import DataFrameGroupBy
@@ -11780,8 +11756,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             value = DataFrame._index_normalized_frame(level, value)

             def assign_columns(
-                psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
-            ) -> Iterator[Tuple["Series", Tuple]]:
+                psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label]
+            ) -> Iterator[Tuple["Series", Label]]:
                 assert len(key) == len(that_column_labels)
                 # Note that here intentionally uses `zip_longest` that combine
                 # that_columns.
@@ -11821,9 +11797,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         self._update_internal_frame(psdf._internal)

     @staticmethod
-    def _index_normalized_label(
-        level: int, labels: Union[Any, Tuple, Sequence[Union[Any, Tuple]]]
-    ) -> List[Tuple]:
+    def _index_normalized_label(level: int, labels: Union[Name, Sequence[Name]]) -> List[Label]:
         """
         Returns a label that is normalized against the current column index level.
         For example, the key "abc" can be ("abc", "", "") if the current Frame has
@@ -11910,7 +11884,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         ]
         return list(super().__dir__()) + fields

-    def __iter__(self) -> Iterator[Union[Any, Tuple]]:
+    def __iter__(self) -> Iterator[Name]:
         return iter(self.columns)

     # NDArray Compat
@@ -11930,8 +11904,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             # Different DataFrames
             def apply_op(
-                psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
-            ) -> Iterator[Tuple["Series", Tuple]]:
+                psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label]
+            ) -> Iterator[Tuple["Series", Label]]:
                 for this_label, that_label in zip(this_column_labels, that_column_labels):
                     yield (
                         ufunc(


@@ -53,7 +53,15 @@ from pyspark.sql.types import (
 )
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
-from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, FrameLike, Scalar
+from pyspark.pandas._typing import (
+    Axis,
+    DataFrameOrSeries,
+    Dtype,
+    FrameLike,
+    Label,
+    Name,
+    Scalar,
+)
 from pyspark.pandas.indexing import AtIndexer, iAtIndexer, iLocIndexer, LocIndexer
 from pyspark.pandas.internal import InternalFrame
 from pyspark.pandas.spark import functions as SF
@@ -636,7 +644,7 @@ class Frame(object, metaclass=ABCMeta):
         path: Optional[str] = None,
         sep: str = ",",
         na_rep: str = "",
-        columns: Optional[List[Union[Any, Tuple]]] = None,
+        columns: Optional[List[Name]] = None,
         header: bool = True,
         quotechar: str = '"',
         date_format: Optional[str] = None,
@@ -811,9 +819,11 @@ class Frame(object, metaclass=ABCMeta):
             column_labels = psdf._internal.column_labels
         else:
             column_labels = []
-            for label in columns:
-                if not is_name_like_tuple(label):
-                    label = (label,)
+            for col in columns:
+                if is_name_like_tuple(col):
+                    label = cast(Label, col)
+                else:
+                    label = cast(Label, (col,))
                 if label not in psdf._internal.column_labels:
                     raise KeyError(name_like_string(label))
                 column_labels.append(label)
@@ -2119,7 +2129,7 @@ class Frame(object, metaclass=ABCMeta):
     # should be updated when it's supported.
     def groupby(
         self: FrameLike,
-        by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]],
+        by: Union[Name, "Series", List[Union[Name, "Series"]]],
         axis: Axis = 0,
         as_index: bool = True,
         dropna: bool = True,
@@ -2206,15 +2216,15 @@ class Frame(object, metaclass=ABCMeta):
         if isinstance(by, ps.DataFrame):
             raise ValueError("Grouper for '{}' not 1-dimensional".format(type(by).__name__))
         elif isinstance(by, ps.Series):
-            new_by = [by]  # type: List[Union[Tuple, ps.Series]]
+            new_by = [by]  # type: List[Union[Label, ps.Series]]
         elif is_name_like_tuple(by):
             if isinstance(self, ps.Series):
                 raise KeyError(by)
-            new_by = [cast(Tuple, by)]
+            new_by = [cast(Label, by)]
         elif is_name_like_value(by):
             if isinstance(self, ps.Series):
                 raise KeyError(by)
-            new_by = [(by,)]
+            new_by = [cast(Label, (by,))]
         elif is_list_like(by):
             new_by = []
             for key in by:
@@ -2227,11 +2237,11 @@ class Frame(object, metaclass=ABCMeta):
                 elif is_name_like_tuple(key):
                     if isinstance(self, ps.Series):
                         raise KeyError(key)
-                    new_by.append(key)
+                    new_by.append(cast(Label, key))
                 elif is_name_like_value(key):
                     if isinstance(self, ps.Series):
                         raise KeyError(key)
-                    new_by.append((key,))
+                    new_by.append(cast(Label, (key,)))
                 else:
                     raise ValueError(
                         "Grouper for '{}' not 1-dimensional".format(type(key).__name__)
@@ -2248,7 +2258,7 @@ class Frame(object, metaclass=ABCMeta):
     @abstractmethod
     def _build_groupby(
-        self: FrameLike, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool
+        self: FrameLike, by: List[Union["Series", Label]], as_index: bool, dropna: bool
     ) -> "GroupBy[FrameLike]":
         pass


@@ -58,7 +58,7 @@ from pyspark.sql.types import (  # noqa: F401
 )
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
-from pyspark.pandas._typing import Axis, FrameLike
+from pyspark.pandas._typing import Axis, FrameLike, Label, Name
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
 from pyspark.pandas.frame import DataFrame
 from pyspark.pandas.internal import (
@@ -110,7 +110,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         groupkeys: List[Series],
         as_index: bool,
         dropna: bool,
-        column_labels_to_exlcude: Set[Tuple],
+        column_labels_to_exlcude: Set[Label],
         agg_columns_selected: bool,
         agg_columns: List[Series],
     ):
@@ -147,9 +147,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     # TODO: not all arguments are implemented comparing to pandas' for now.
     def aggregate(
         self,
-        func_or_funcs: Optional[
-            Union[str, List[str], Dict[Union[Any, Tuple], Union[str, List[str]]]]
-        ] = None,
+        func_or_funcs: Optional[Union[str, List[str], Dict[Name, Union[str, List[str]]]]] = None,
         *args: Any,
         **kwargs: Any
     ) -> DataFrame:
@@ -312,7 +310,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     @staticmethod
     def _spark_groupby(
         psdf: DataFrame,
-        func: Mapping[Union[Any, Tuple], Union[str, List[str]]],
+        func: Mapping[Name, Union[str, List[str]]],
         groupkeys: Sequence[Series] = (),
     ) -> InternalFrame:
         groupkey_names = [SPARK_INDEX_NAME_FORMAT(i) for i in range(len(groupkeys))]
@@ -1405,11 +1403,11 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     @staticmethod
     def _prepare_group_map_apply(
         psdf: DataFrame, groupkeys: List[Series], agg_columns: List[Series]
-    ) -> Tuple[DataFrame, List[Tuple], List[str]]:
+    ) -> Tuple[DataFrame, List[Label], List[str]]:
         groupkey_labels = [
             verify_temp_column_name(psdf, "__groupkey_{}__".format(i))
             for i in range(len(groupkeys))
-        ]  # type: List[Tuple]
+        ]  # type: List[Label]
         psdf = psdf[[s.rename(label) for s, label in zip(groupkeys, groupkey_labels)] + agg_columns]
         groupkey_names = [label if len(label) > 1 else label[0] for label in groupkey_labels]
         return DataFrame(psdf._internal.resolved_copy), groupkey_labels, groupkey_names
@@ -2377,7 +2375,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         return ExpandingGroupby(self, min_periods=min_periods)

-    def get_group(self, name: Union[Any, Tuple, List[Union[Any, Tuple]]]) -> FrameLike:
+    def get_group(self, name: Union[Name, List[Name]]) -> FrameLike:
         """
         Construct DataFrame from group with provided name.
@@ -2594,8 +2592,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
     @staticmethod
     def _resolve_grouping_from_diff_dataframes(
-        psdf: DataFrame, by: List[Union[Series, Tuple]]
-    ) -> Tuple[DataFrame, List[Series], Set[Tuple]]:
+        psdf: DataFrame, by: List[Union[Series, Label]]
+    ) -> Tuple[DataFrame, List[Series], Set[Label]]:
         column_labels_level = psdf._internal.column_labels_level

         column_labels = []
@@ -2636,8 +2634,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             )

             def assign_columns(
-                psdf: DataFrame, this_column_labels: List[Tuple], that_column_labels: List[Tuple]
-            ) -> Iterator[Tuple[Series, Tuple]]:
+                psdf: DataFrame, this_column_labels: List[Label], that_column_labels: List[Label]
+            ) -> Iterator[Tuple[Series, Label]]:
                 raise NotImplementedError(
                     "Duplicated labels with groupby() and "
                     "'compute.ops_on_diff_frames' option are not supported currently "
@@ -2669,7 +2667,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         return psdf, new_by_series, tmp_column_labels

     @staticmethod
-    def _resolve_grouping(psdf: DataFrame, by: List[Union[Series, Tuple]]) -> List[Series]:
+    def _resolve_grouping(psdf: DataFrame, by: List[Union[Series, Label]]) -> List[Series]:
         new_by_series = []
         for col_or_s in by:
             if isinstance(col_or_s, Series):
@@ -2687,7 +2685,7 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
 class DataFrameGroupBy(GroupBy[DataFrame]):
     @staticmethod
     def _build(
-        psdf: DataFrame, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool
+        psdf: DataFrame, by: List[Union[Series, Label]], as_index: bool, dropna: bool
     ) -> "DataFrameGroupBy":
         if any(isinstance(col_or_s, Series) and not same_anchor(psdf, col_or_s) for col_or_s in by):
             (
@@ -2712,8 +2710,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
         by: List[Series],
         as_index: bool,
         dropna: bool,
-        column_labels_to_exlcude: Set[Tuple],
-        agg_columns: List[Tuple] = None,
+        column_labels_to_exlcude: Set[Label],
+        agg_columns: List[Label] = None,
     ):
         agg_columns_selected = agg_columns is not None
         if agg_columns_selected:
@@ -2891,7 +2889,7 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
 class SeriesGroupBy(GroupBy[Series]):
     @staticmethod
     def _build(
-        psser: Series, by: List[Union[Series, Tuple]], as_index: bool, dropna: bool
+        psser: Series, by: List[Union[Series, Label]], as_index: bool, dropna: bool
     ) -> "SeriesGroupBy":
         if any(
             isinstance(col_or_s, Series) and not same_anchor(psser, col_or_s) for col_or_s in by
@@ -3255,8 +3253,8 @@ def is_multi_agg_with_relabel(**kwargs: Any) -> bool:
 def normalize_keyword_aggregation(
-    kwargs: Dict[str, Tuple[Union[Any, Tuple], str]],
-) -> Tuple[Dict[Union[Any, Tuple], List[str]], List[str], List[Tuple]]:
+    kwargs: Dict[str, Tuple[Name, str]],
+) -> Tuple[Dict[Name, List[str]], List[str], List[Tuple]]:
     """
     Normalize user-provided kwargs.

@@ -40,7 +40,7 @@ from pyspark.sql import functions as F, Column
 from pyspark.sql.types import FractionalType, IntegralType, TimestampType
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
-from pyspark.pandas._typing import Dtype, Scalar
+from pyspark.pandas._typing import Dtype, Label, Name, Scalar
 from pyspark.pandas.config import get_option, option_context
 from pyspark.pandas.base import IndexOpsMixin
 from pyspark.pandas.frame import DataFrame
@@ -125,7 +125,7 @@ class Index(IndexOpsMixin):
         data: Optional[Any] = None,
         dtype: Optional[Union[str, Dtype]] = None,
         copy: bool = False,
-        name: Optional[Union[Any, Tuple]] = None,
+        name: Optional[Name] = None,
         tupleize_cols: bool = True,
         **kwargs: Any
     ) -> "Index":
@@ -215,7 +215,7 @@ class Index(IndexOpsMixin):
         )

     @property
-    def _column_label(self) -> Optional[Tuple]:
+    def _column_label(self) -> Optional[Label]:
         return self._psdf._internal.index_names[0]

     def _with_new_scol(self, scol: Column, *, field: Optional[InternalField] = None) -> "Index":
@@ -636,24 +636,24 @@ class Index(IndexOpsMixin):
         return not self.has_duplicates

     @property
-    def name(self) -> Union[Any, Tuple]:
+    def name(self) -> Name:
         """Return name of the Index."""
         return self.names[0]

     @name.setter
-    def name(self, name: Union[Any, Tuple]) -> None:
+    def name(self, name: Name) -> None:
         self.names = [name]

     @property
-    def names(self) -> List[Union[Any, Tuple]]:
+    def names(self) -> List[Name]:
         """Return names of the Index."""
         return [
             name if name is None or len(name) > 1 else name[0]
-            for name in self._internal.index_names  # type: ignore
+            for name in self._internal.index_names
         ]

     @names.setter
-    def names(self, names: List[Union[Any, Tuple]]) -> None:
+    def names(self, names: List[Name]) -> None:
         if not is_list_like(names):
             raise ValueError("Names must be a list-like")
         if self._internal.index_level != len(names):
@@ -684,9 +684,7 @@ class Index(IndexOpsMixin):
         """
         return self._internal.index_level

-    def rename(
-        self, name: Union[Any, Tuple, List[Union[Any, Tuple]]], inplace: bool = False
-    ) -> Optional["Index"]:
+    def rename(self, name: Union[Name, List[Name]], inplace: bool = False) -> Optional["Index"]:
         """
         Alter Index or MultiIndex name.
         Able to set new names without level. Defaults to returning new index.
@@ -749,7 +747,7 @@ class Index(IndexOpsMixin):
         else:
             return DataFrame(internal).index

-    def _verify_for_rename(self, name: Union[Any, Tuple]) -> List[Tuple]:
+    def _verify_for_rename(self, name: Name) -> List[Label]:
         if is_hashable(name):
             if is_name_like_tuple(name):
                 return [name]
@@ -830,7 +828,7 @@ class Index(IndexOpsMixin):
         )
         return DataFrame(internal).index

-    def to_series(self, name: Optional[Union[Any, Tuple]] = None) -> Series:
+    def to_series(self, name: Optional[Name] = None) -> Series:
         """
         Create a Series with both index and values equal to the index keys
         useful with map for returning an indexer based on an index.
@@ -868,7 +866,7 @@ class Index(IndexOpsMixin):
             name = self.name
         column_labels = [
             name if is_name_like_tuple(name) else (name,)
-        ]  # type: List[Optional[Tuple]]
+        ]  # type: List[Optional[Label]]
         internal = self._internal.copy(
             column_labels=column_labels,
             data_spark_columns=[scol],
@@ -877,7 +875,7 @@ class Index(IndexOpsMixin):
         )
         return first_series(DataFrame(internal))

-    def to_frame(self, index: bool = True, name: Optional[Union[Any, Tuple]] = None) -> DataFrame:
+    def to_frame(self, index: bool = True, name: Optional[Name] = None) -> DataFrame:
         """
         Create a DataFrame with a column containing the Index.
@@ -939,7 +937,7 @@ class Index(IndexOpsMixin):
         return self._to_frame(index=index, names=[name])

-    def _to_frame(self, index: bool, names: List[Tuple]) -> DataFrame:
+    def _to_frame(self, index: bool, names: List[Label]) -> DataFrame:
         if index:
             index_spark_columns = self._internal.index_spark_columns
             index_names = self._internal.index_names
@@ -1115,7 +1113,7 @@ class Index(IndexOpsMixin):
         )
         return DataFrame(internal).index

-    def unique(self, level: Optional[Union[int, Any, Tuple]] = None) -> "Index":
+    def unique(self, level: Optional[Union[int, Name]] = None) -> "Index":
         """
         Return unique values in the index.
@@ -1203,7 +1201,7 @@ class Index(IndexOpsMixin):
         )
         return DataFrame(internal).index

-    def _validate_index_level(self, level: Union[int, Any, Tuple]) -> None:
+    def _validate_index_level(self, level: Union[int, Name]) -> None:
         """
         Validate index level.
         For single-level Index getting level number is a no-op, but some
@@ -1222,7 +1220,7 @@ class Index(IndexOpsMixin):
                 "Requested level ({}) does not match index name ({})".format(level, self.name)
             )

-    def get_level_values(self, level: Union[int, Any, Tuple]) -> "Index":
+    def get_level_values(self, level: Union[int, Name]) -> "Index":
         """
         Return Index if a valid level is given.
@ -1238,9 +1236,7 @@ class Index(IndexOpsMixin):
self._validate_index_level(level) self._validate_index_level(level)
return self return self
def copy( def copy(self, name: Optional[Name] = None, deep: Optional[bool] = None) -> "Index":
self, name: Optional[Union[Any, Tuple]] = None, deep: Optional[bool] = None
) -> "Index":
""" """
Make a copy of this object. name sets those attributes on the new object. Make a copy of this object. name sets those attributes on the new object.
@ -1279,7 +1275,7 @@ class Index(IndexOpsMixin):
result.name = name result.name = name
return result return result
def droplevel(self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]) -> "Index": def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Index":
""" """
Return index with requested level(s) removed. Return index with requested level(s) removed.
If resulting index has only 1 level left, the result will be If resulting index has only 1 level left, the result will be
@ -1317,9 +1313,9 @@ class Index(IndexOpsMixin):
names = self.names names = self.names
nlevels = self.nlevels nlevels = self.nlevels
if not is_list_like(level): if not is_list_like(level):
levels = [cast(Union[int, Any, Tuple], level)] levels = [cast(Union[int, Name], level)]
else: else:
levels = cast(List[Union[int, Any, Tuple]], level) levels = cast(List[Union[int, Name]], level)
int_level = set() int_level = set()
for n in levels: for n in levels:
@ -1375,7 +1371,7 @@ class Index(IndexOpsMixin):
def symmetric_difference( def symmetric_difference(
self, self,
other: "Index", other: "Index",
result_name: Optional[Union[Any, Tuple]] = None, result_name: Optional[Name] = None,
sort: Optional[bool] = None, sort: Optional[bool] = None,
) -> "Index": ) -> "Index":
""" """
@ -1887,8 +1883,8 @@ class Index(IndexOpsMixin):
def set_names( def set_names(
self, self,
names: Union[Any, Tuple, List[Union[Any, Tuple]]], names: Union[Name, List[Name]],
level: Optional[Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]] = None, level: Optional[Union[int, Name, List[Union[int, Name]]]] = None,
inplace: bool = False, inplace: bool = False,
) -> Optional["Index"]: ) -> Optional["Index"]:
""" """

View file

@ -28,7 +28,7 @@ from pyspark.sql.types import DataType
# For running doctests and reference resolution in PyCharm. # For running doctests and reference resolution in PyCharm.
from pyspark import pandas as ps # noqa: F401 from pyspark import pandas as ps # noqa: F401
from pyspark.pandas._typing import Scalar from pyspark.pandas._typing import Label, Name, Scalar
from pyspark.pandas.exceptions import PandasNotImplementedError from pyspark.pandas.exceptions import PandasNotImplementedError
from pyspark.pandas.frame import DataFrame from pyspark.pandas.frame import DataFrame
from pyspark.pandas.indexes.base import Index from pyspark.pandas.indexes.base import Index
@ -146,7 +146,7 @@ class MultiIndex(Index):
) )
@property @property
def _column_label(self) -> Optional[Tuple]: def _column_label(self) -> Optional[Label]:
return None return None
def __abs__(self) -> "MultiIndex": def __abs__(self) -> "MultiIndex":
@ -169,7 +169,7 @@ class MultiIndex(Index):
def from_tuples( def from_tuples(
tuples: List[Tuple], tuples: List[Tuple],
sortorder: Optional[int] = None, sortorder: Optional[int] = None,
names: Optional[List[Union[Any, Tuple]]] = None, names: Optional[List[Name]] = None,
) -> "MultiIndex": ) -> "MultiIndex":
""" """
Convert list of tuples to MultiIndex. Convert list of tuples to MultiIndex.
@ -210,7 +210,7 @@ class MultiIndex(Index):
def from_arrays( def from_arrays(
arrays: List[List], arrays: List[List],
sortorder: Optional[int] = None, sortorder: Optional[int] = None,
names: Optional[List[Union[Any, Tuple]]] = None, names: Optional[List[Name]] = None,
) -> "MultiIndex": ) -> "MultiIndex":
""" """
Convert arrays to MultiIndex. Convert arrays to MultiIndex.
@ -251,7 +251,7 @@ class MultiIndex(Index):
def from_product( def from_product(
iterables: List[List], iterables: List[List],
sortorder: Optional[int] = None, sortorder: Optional[int] = None,
names: Optional[List[Union[Any, Tuple]]] = None, names: Optional[List[Name]] = None,
) -> "MultiIndex": ) -> "MultiIndex":
""" """
Make a MultiIndex from the cartesian product of multiple iterables. Make a MultiIndex from the cartesian product of multiple iterables.
@ -297,7 +297,7 @@ class MultiIndex(Index):
) )
@staticmethod @staticmethod
def from_frame(df: DataFrame, names: Optional[List[Union[Any, Tuple]]] = None) -> "MultiIndex": def from_frame(df: DataFrame, names: Optional[List[Name]] = None) -> "MultiIndex":
""" """
Make a MultiIndex from a DataFrame. Make a MultiIndex from a DataFrame.
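
All four `MultiIndex` factory methods (`from_tuples`, `from_arrays`, `from_product`, `from_frame`) now annotate `names` as `Optional[List[Name]]`: each per-level name may be a scalar or a tuple. A usage sketch (assumes a running Spark session):

```python
import pyspark.pandas as ps

# Each element of `names` is a Name; scalars are typical, tuples are also accepted.
midx = ps.MultiIndex.from_tuples([("a", 1), ("b", 2)], names=["letter", "number"])
print(midx.names)  # expected: ['letter', 'number']
```
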
@ -369,16 +369,14 @@ class MultiIndex(Index):
return cast(MultiIndex, DataFrame(internal).index) return cast(MultiIndex, DataFrame(internal).index)
@property @property
def name(self) -> Union[Any, Tuple]: def name(self) -> Name:
raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name") raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name")
@name.setter @name.setter
def name(self, name: Union[Any, Tuple]) -> None: def name(self, name: Name) -> None:
raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name") raise PandasNotImplementedError(class_name="pd.MultiIndex", property_name="name")
def _verify_for_rename( # type: ignore[override] def _verify_for_rename(self, name: List[Name]) -> List[Label]: # type: ignore[override]
self, name: List[Union[Any, Tuple]]
) -> List[Tuple]:
if is_list_like(name): if is_list_like(name):
if self._internal.index_level != len(name): if self._internal.index_level != len(name):
raise ValueError( raise ValueError(
@ -575,7 +573,7 @@ class MultiIndex(Index):
return first_series(DataFrame(internal)) return first_series(DataFrame(internal))
def to_frame( # type: ignore[override] def to_frame( # type: ignore[override]
self, index: bool = True, name: Optional[List[Union[Any, Tuple]]] = None self, index: bool = True, name: Optional[List[Name]] = None
) -> DataFrame: ) -> DataFrame:
""" """
Create a DataFrame with the levels of the MultiIndex as columns. Create a DataFrame with the levels of the MultiIndex as columns.
@ -712,7 +710,7 @@ class MultiIndex(Index):
def symmetric_difference( # type: ignore[override] def symmetric_difference( # type: ignore[override]
self, self,
other: Index, other: Index,
result_name: Optional[List[Union[Any, Tuple]]] = None, result_name: Optional[List[Name]] = None,
sort: Optional[bool] = None, sort: Optional[bool] = None,
) -> "MultiIndex": ) -> "MultiIndex":
""" """
@ -807,9 +805,7 @@ class MultiIndex(Index):
return result return result
# TODO: ADD error parameter # TODO: ADD error parameter
def drop( def drop(self, codes: List[Any], level: Optional[Union[int, Name]] = None) -> "MultiIndex":
self, codes: List[Any], level: Optional[Union[int, Any, Tuple]] = None
) -> "MultiIndex":
""" """
Make new MultiIndex with passed list of labels deleted Make new MultiIndex with passed list of labels deleted
@ -920,7 +916,7 @@ class MultiIndex(Index):
return partial(property_or_func, self) return partial(property_or_func, self)
raise AttributeError("'MultiIndex' object has no attribute '{}'".format(item)) raise AttributeError("'MultiIndex' object has no attribute '{}'".format(item))
def _get_level_number(self, level: Union[int, Any, Tuple]) -> int: def _get_level_number(self, level: Union[int, Name]) -> int:
""" """
Return the level number if a valid level is given. Return the level number if a valid level is given.
""" """
@ -948,7 +944,7 @@ class MultiIndex(Index):
return level return level
def get_level_values(self, level: Union[int, Any, Tuple]) -> Index: def get_level_values(self, level: Union[int, Name]) -> Index:
""" """
Return vector of label values for requested level, Return vector of label values for requested level,
equal to the length of the index. equal to the length of the index.
@ -1046,7 +1042,7 @@ class MultiIndex(Index):
index_name = [ index_name = [
(name,) for name in self._internal.index_spark_column_names (name,) for name in self._internal.index_spark_column_names
] # type: List[Tuple] ] # type: List[Label]
sdf_before = self.to_frame(name=index_name)[:loc].to_spark() sdf_before = self.to_frame(name=index_name)[:loc].to_spark()
sdf_middle = Index([item]).to_frame(name=index_name).to_spark() sdf_middle = Index([item]).to_frame(name=index_name).to_spark()
sdf_after = self.to_frame(name=index_name)[loc:].to_spark() sdf_after = self.to_frame(name=index_name)[loc:].to_spark()

View file

@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Any, Optional, Tuple, Union, cast from typing import Any, Optional, Union, cast
import pandas as pd import pandas as pd
from pandas.api.types import is_hashable from pandas.api.types import is_hashable
from pyspark import pandas as ps from pyspark import pandas as ps
from pyspark.pandas._typing import Dtype from pyspark.pandas._typing import Dtype, Name
from pyspark.pandas.indexes.base import Index from pyspark.pandas.indexes.base import Index
from pyspark.pandas.series import Series from pyspark.pandas.series import Series
@ -89,7 +89,7 @@ class Int64Index(IntegerIndex):
data: Optional[Any] = None, data: Optional[Any] = None,
dtype: Optional[Union[str, Dtype]] = None, dtype: Optional[Union[str, Dtype]] = None,
copy: bool = False, copy: bool = False,
name: Optional[Union[Any, Tuple]] = None, name: Optional[Name] = None,
) -> "Int64Index": ) -> "Int64Index":
if not is_hashable(name): if not is_hashable(name):
raise TypeError("Index.name must be a hashable type") raise TypeError("Index.name must be a hashable type")
@ -151,7 +151,7 @@ class Float64Index(NumericIndex):
data: Optional[Any] = None, data: Optional[Any] = None,
dtype: Optional[Union[str, Dtype]] = None, dtype: Optional[Union[str, Dtype]] = None,
copy: bool = False, copy: bool = False,
name: Optional[Union[Any, Tuple]] = None, name: Optional[Name] = None,
) -> "Float64Index": ) -> "Float64Index":
if not is_hashable(name): if not is_hashable(name):
raise TypeError("Index.name must be a hashable type") raise TypeError("Index.name must be a hashable type")

View file

@ -31,7 +31,7 @@ from pyspark.sql.utils import AnalysisException
import numpy as np import numpy as np
from pyspark import pandas as ps # noqa: F401 from pyspark import pandas as ps # noqa: F401
from pyspark.pandas._typing import Scalar from pyspark.pandas._typing import Label, Name, Scalar
from pyspark.pandas.internal import ( from pyspark.pandas.internal import (
InternalField, InternalField,
InternalFrame, InternalFrame,
@ -274,13 +274,13 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
return self._select_rows_else(rows_sel) return self._select_rows_else(rows_sel)
def _select_cols( def _select_cols(
self, cols_sel: Any, missing_keys: Optional[List[Tuple]] = None self, cols_sel: Any, missing_keys: Optional[List[Name]] = None
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
""" """
Dispatch the logic for select columns to more specific methods by `cols_sel` argument types. Dispatch the logic for select columns to more specific methods by `cols_sel` argument types.
@ -366,65 +366,65 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
@abstractmethod @abstractmethod
def _select_cols_by_series( def _select_cols_by_series(
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] self, cols_sel: "Series", missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
"""Select columns by `Series` type key.""" """Select columns by `Series` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_cols_by_spark_column( def _select_cols_by_spark_column(
self, cols_sel: Column, missing_keys: Optional[List[Tuple]] self, cols_sel: Column, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
"""Select columns by Spark `Column` type key.""" """Select columns by Spark `Column` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_cols_by_slice( def _select_cols_by_slice(
self, cols_sel: slice, missing_keys: Optional[List[Tuple]] self, cols_sel: slice, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
"""Select columns by `slice` type key.""" """Select columns by `slice` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_cols_by_iterable( def _select_cols_by_iterable(
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] self, cols_sel: Iterable, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
"""Select columns by `Iterable` type key.""" """Select columns by `Iterable` type key."""
pass pass
@abstractmethod @abstractmethod
def _select_cols_else( def _select_cols_else(
self, cols_sel: Any, missing_keys: Optional[List[Tuple]] self, cols_sel: Any, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
"""Select columns by other type key.""" """Select columns by other type key."""
pass pass
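
Every `_select_cols*` variant shares one return shape, now spelled with the aliases: `(column_labels, data_spark_columns, data_fields, returns_series, series_name)`. A stand-alone sketch of that shape, with `str` standing in for pyspark `Column` and `InternalField`:

```python
from typing import Any, List, Optional, Tuple

SelectColsResult = Tuple[
    List[Tuple[Any, ...]],  # column_labels: List[Label]
    Optional[List[str]],    # data_spark_columns (Column stand-ins)
    Optional[List[str]],    # data_fields (InternalField stand-ins)
    bool,                   # returns_series
    Optional[Any],          # series_name: Optional[Name]
]
```
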
@ -706,7 +706,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
return return
cond, limit, remaining_index = self._select_rows(rows_sel) cond, limit, remaining_index = self._select_rows(rows_sel)
missing_keys = [] # type: Optional[List[Tuple]] missing_keys = [] # type: Optional[List[Name]]
_, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys) _, data_spark_columns, _, _, _ = self._select_cols(cols_sel, missing_keys=missing_keys)
if cond is None: if cond is None:
@ -746,9 +746,11 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
new_fields.append(new_field) new_fields.append(new_field)
column_labels = self._internal.column_labels.copy() column_labels = self._internal.column_labels.copy()
for label in missing_keys: for missing in missing_keys:
if not is_name_like_tuple(label): if is_name_like_tuple(missing):
label = (label,) label = cast(Label, missing)
else:
label = cast(Label, (missing,))
if len(label) < self._internal.column_labels_level: if len(label) < self._internal.column_labels_level:
label = tuple( label = tuple(
list(label) + ([""] * (self._internal.column_labels_level - len(label))) list(label) + ([""] * (self._internal.column_labels_level - len(label)))
@ -1167,11 +1169,11 @@ class LocIndexer(LocIndexerLike):
def _get_from_multiindex_column( def _get_from_multiindex_column(
self, self,
key: Optional[Tuple], key: Optional[Label],
missing_keys: Optional[List[Tuple]], missing_keys: Optional[List[Name]],
labels: Optional[List[Tuple]] = None, labels: Optional[List[Tuple[Label, Label]]] = None,
recursed: int = 0, recursed: int = 0,
) -> Tuple[List[Tuple], Optional[List[Column]], List[InternalField], bool, Optional[Tuple]]: ) -> Tuple[List[Label], Optional[List[Column]], List[InternalField], bool, Optional[Name]]:
"""Select columns from multi-index columns.""" """Select columns from multi-index columns."""
assert isinstance(key, tuple) assert isinstance(key, tuple)
if labels is None: if labels is None:
@ -1196,14 +1198,14 @@ class LocIndexer(LocIndexerLike):
else: else:
returns_series = all(lbl is None or len(lbl) == 0 for _, lbl in labels) returns_series = all(lbl is None or len(lbl) == 0 for _, lbl in labels)
if returns_series: if returns_series:
labels = set(label for label, _ in labels) # type: ignore label_set = set(label for label, _ in labels)
assert len(labels) == 1 assert len(label_set) == 1
label = list(labels)[0] label = list(label_set)[0]
column_labels = [label] column_labels = [label]
data_spark_columns = [self._internal.spark_column_for(label)] data_spark_columns = [self._internal.spark_column_for(label)]
data_fields = [self._internal.field_for(label)] data_fields = [self._internal.field_for(label)]
if label is None: if label is None:
series_name = None series_name = None # type: Name
else: else:
if recursed > 0: if recursed > 0:
label = label[:-recursed] label = label[:-recursed]
@ -1219,13 +1221,13 @@ class LocIndexer(LocIndexerLike):
return column_labels, data_spark_columns, data_fields, returns_series, series_name return column_labels, data_spark_columns, data_fields, returns_series, series_name
def _select_cols_by_series( def _select_cols_by_series(
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] self, cols_sel: "Series", missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
column_labels = cols_sel._internal.column_labels column_labels = cols_sel._internal.column_labels
data_spark_columns = cols_sel._internal.data_spark_columns data_spark_columns = cols_sel._internal.data_spark_columns
@ -1233,28 +1235,28 @@ class LocIndexer(LocIndexerLike):
return column_labels, data_spark_columns, data_fields, True, None return column_labels, data_spark_columns, data_fields, True, None
def _select_cols_by_spark_column( def _select_cols_by_spark_column(
self, cols_sel: Column, missing_keys: Optional[List[Tuple]] self, cols_sel: Column, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
column_labels = [ column_labels = [
(self._internal.spark_frame.select(cols_sel).columns[0],) (self._internal.spark_frame.select(cols_sel).columns[0],)
] # type: List[Tuple] ] # type: List[Label]
data_spark_columns = [cols_sel] data_spark_columns = [cols_sel]
return column_labels, data_spark_columns, None, True, None return column_labels, data_spark_columns, None, True, None
def _select_cols_by_slice( def _select_cols_by_slice(
self, cols_sel: slice, missing_keys: Optional[List[Tuple]] self, cols_sel: slice, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
start, stop = self._psdf_or_psser.columns.slice_locs( start, stop = self._psdf_or_psser.columns.slice_locs(
start=cols_sel.start, end=cols_sel.stop start=cols_sel.start, end=cols_sel.stop
@ -1265,13 +1267,13 @@ class LocIndexer(LocIndexerLike):
return column_labels, data_spark_columns, data_fields, False, None return column_labels, data_spark_columns, data_fields, False, None
def _select_cols_by_iterable( def _select_cols_by_iterable(
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] self, cols_sel: Iterable, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
from pyspark.pandas.series import Series from pyspark.pandas.series import Series
@ -1356,13 +1358,13 @@ class LocIndexer(LocIndexerLike):
return column_labels, data_spark_columns, data_fields, False, None return column_labels, data_spark_columns, data_fields, False, None
def _select_cols_else( def _select_cols_else(
self, cols_sel: Any, missing_keys: Optional[List[Tuple]] self, cols_sel: Any, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
if not is_name_like_tuple(cols_sel): if not is_name_like_tuple(cols_sel):
cols_sel = (cols_sel,) cols_sel = (cols_sel,)
@ -1552,7 +1554,7 @@ class iLocIndexer(LocIndexerLike):
) )
@lazy_property @lazy_property
def _sequence_col(self) -> Union[Any, Tuple]: def _sequence_col(self) -> str:
# Use resolved_copy to fix the natural order. # Use resolved_copy to fix the natural order.
internal = super()._internal.resolved_copy internal = super()._internal.resolved_copy
return verify_temp_column_name(internal.spark_frame, "__distributed_sequence_column__") return verify_temp_column_name(internal.spark_frame, "__distributed_sequence_column__")
@ -1692,13 +1694,13 @@ class iLocIndexer(LocIndexerLike):
) )
def _select_cols_by_series( def _select_cols_by_series(
self, cols_sel: "Series", missing_keys: Optional[List[Tuple]] self, cols_sel: "Series", missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
raise ValueError( raise ValueError(
"Location based indexing can only have [integer, integer slice, " "Location based indexing can only have [integer, integer slice, "
@ -1706,13 +1708,13 @@ class iLocIndexer(LocIndexerLike):
) )
def _select_cols_by_spark_column( def _select_cols_by_spark_column(
self, cols_sel: Column, missing_keys: Optional[List[Tuple]] self, cols_sel: Column, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
raise ValueError( raise ValueError(
"Location based indexing can only have [integer, integer slice, " "Location based indexing can only have [integer, integer slice, "
@ -1720,13 +1722,13 @@ class iLocIndexer(LocIndexerLike):
) )
def _select_cols_by_slice( def _select_cols_by_slice(
self, cols_sel: slice, missing_keys: Optional[List[Tuple]] self, cols_sel: slice, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
if all( if all(
s is None or isinstance(s, int) for s in (cols_sel.start, cols_sel.stop, cols_sel.step) s is None or isinstance(s, int) for s in (cols_sel.start, cols_sel.stop, cols_sel.step)
@ -1750,13 +1752,13 @@ class iLocIndexer(LocIndexerLike):
) )
def _select_cols_by_iterable( def _select_cols_by_iterable(
self, cols_sel: Iterable, missing_keys: Optional[List[Tuple]] self, cols_sel: Iterable, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
if all(isinstance(s, bool) for s in cols_sel): if all(isinstance(s, bool) for s in cols_sel):
cols_sel = [i for i, s in enumerate(cols_sel) if s] cols_sel = [i for i, s in enumerate(cols_sel) if s]
@ -1769,13 +1771,13 @@ class iLocIndexer(LocIndexerLike):
raise TypeError("cannot perform reduce with flexible type") raise TypeError("cannot perform reduce with flexible type")
def _select_cols_else( def _select_cols_else(
self, cols_sel: Any, missing_keys: Optional[List[Tuple]] self, cols_sel: Any, missing_keys: Optional[List[Name]]
) -> Tuple[ ) -> Tuple[
List[Tuple], List[Label],
Optional[List[Column]], Optional[List[Column]],
Optional[List[InternalField]], Optional[List[InternalField]],
bool, bool,
Optional[Tuple], Optional[Name],
]: ]:
if isinstance(cols_sel, int): if isinstance(cols_sel, int):
if cols_sel > len(self._internal.column_labels): if cols_sel > len(self._internal.column_labels):

View file

@ -41,6 +41,7 @@ from pyspark.sql.types import ( # noqa: F401
# For running doctests and reference resolution in PyCharm. # For running doctests and reference resolution in PyCharm.
from pyspark import pandas as ps from pyspark import pandas as ps
from pyspark.pandas._typing import Label
if TYPE_CHECKING: if TYPE_CHECKING:
# This is required in old Python 3.5 to prevent circular reference. # This is required in old Python 3.5 to prevent circular reference.
@ -529,12 +530,12 @@ class InternalFrame(object):
self, self,
spark_frame: SparkDataFrame, spark_frame: SparkDataFrame,
index_spark_columns: Optional[List[Column]], index_spark_columns: Optional[List[Column]],
index_names: Optional[List[Optional[Tuple]]] = None, index_names: Optional[List[Optional[Label]]] = None,
index_fields: Optional[List[InternalField]] = None, index_fields: Optional[List[InternalField]] = None,
column_labels: Optional[List[Tuple]] = None, column_labels: Optional[List[Label]] = None,
data_spark_columns: Optional[List[Column]] = None, data_spark_columns: Optional[List[Column]] = None,
data_fields: Optional[List[InternalField]] = None, data_fields: Optional[List[InternalField]] = None,
column_label_names: Optional[List[Optional[Tuple]]] = None, column_label_names: Optional[List[Optional[Label]]] = None,
): ):
""" """
Create a new internal immutable DataFrame to manage Spark DataFrame, column fields and Create a new internal immutable DataFrame to manage Spark DataFrame, column fields and
@ -783,13 +784,13 @@ class InternalFrame(object):
is_name_like_tuple(index_name, check_type=True) for index_name in index_names is_name_like_tuple(index_name, check_type=True) for index_name in index_names
), index_names ), index_names
self._index_names = index_names # type: List[Optional[Tuple]] self._index_names = index_names # type: List[Optional[Label]]
# column_labels # column_labels
if column_labels is None: if column_labels is None:
self._column_labels = [ self._column_labels = [
(col,) for col in spark_frame.select(self._data_spark_columns).columns (col,) for col in spark_frame.select(self._data_spark_columns).columns
] # type: List[Tuple] ] # type: List[Label]
else: else:
assert len(column_labels) == len(self._data_spark_columns), ( assert len(column_labels) == len(self._data_spark_columns), (
len(column_labels), len(column_labels),
@ -810,7 +811,7 @@ class InternalFrame(object):
if column_label_names is None: if column_label_names is None:
self._column_label_names = [None] * column_labels_level( self._column_label_names = [None] * column_labels_level(
self._column_labels self._column_labels
) # type: List[Optional[Tuple]] ) # type: List[Optional[Label]]
else: else:
if len(self._column_labels) > 0: if len(self._column_labels) > 0:
assert len(column_label_names) == column_labels_level(self._column_labels), ( assert len(column_label_names) == column_labels_level(self._column_labels), (
@ -1027,7 +1028,7 @@ class InternalFrame(object):
False, False,
) )
def spark_column_for(self, label: Tuple) -> Column: def spark_column_for(self, label: Label) -> Column:
"""Return Spark Column for the given column label.""" """Return Spark Column for the given column label."""
column_labels_to_scol = dict(zip(self.column_labels, self.data_spark_columns)) column_labels_to_scol = dict(zip(self.column_labels, self.data_spark_columns))
if label in column_labels_to_scol: if label in column_labels_to_scol:
@ -1035,28 +1036,28 @@ class InternalFrame(object):
else: else:
raise KeyError(name_like_string(label)) raise KeyError(name_like_string(label))
def spark_column_name_for(self, label_or_scol: Union[Tuple, Column]) -> str: def spark_column_name_for(self, label_or_scol: Union[Label, Column]) -> str:
"""Return the actual Spark column name for the given column label.""" """Return the actual Spark column name for the given column label."""
if isinstance(label_or_scol, Column): if isinstance(label_or_scol, Column):
return self.spark_frame.select(label_or_scol).columns[0] return self.spark_frame.select(label_or_scol).columns[0]
else: else:
return self.field_for(label_or_scol).name return self.field_for(label_or_scol).name
def spark_type_for(self, label_or_scol: Union[Tuple, Column]) -> DataType: def spark_type_for(self, label_or_scol: Union[Label, Column]) -> DataType:
"""Return DataType for the given column label.""" """Return DataType for the given column label."""
if isinstance(label_or_scol, Column): if isinstance(label_or_scol, Column):
return self.spark_frame.select(label_or_scol).schema[0].dataType return self.spark_frame.select(label_or_scol).schema[0].dataType
else: else:
return self.field_for(label_or_scol).spark_type return self.field_for(label_or_scol).spark_type
def spark_column_nullable_for(self, label_or_scol: Union[Tuple, Column]) -> bool: def spark_column_nullable_for(self, label_or_scol: Union[Label, Column]) -> bool:
"""Return nullability for the given column label.""" """Return nullability for the given column label."""
if isinstance(label_or_scol, Column): if isinstance(label_or_scol, Column):
return self.spark_frame.select(label_or_scol).schema[0].nullable return self.spark_frame.select(label_or_scol).schema[0].nullable
else: else:
return self.field_for(label_or_scol).nullable return self.field_for(label_or_scol).nullable
def field_for(self, label: Tuple) -> InternalField: def field_for(self, label: Label) -> InternalField:
"""Return InternalField for the given column label.""" """Return InternalField for the given column label."""
column_labels_to_fields = dict(zip(self.column_labels, self.data_fields)) column_labels_to_fields = dict(zip(self.column_labels, self.data_fields))
if label in column_labels_to_fields: if label in column_labels_to_fields:
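
These lookup helpers show why `Label` has to be a tuple rather than an arbitrary name: column labels are zipped into plain dicts, so the keys must be hashable and uniformly shaped. A toy illustration with strings standing in for Spark `Column`s:

```python
column_labels = [("a",), ("b", "c")]         # Labels are always tuples
data_spark_columns = ["scol_a", "scol_b_c"]  # stand-ins for pyspark Columns

column_labels_to_scol = dict(zip(column_labels, data_spark_columns))
assert column_labels_to_scol[("b", "c")] == "scol_b_c"
```
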
@ -1105,7 +1106,7 @@ class InternalFrame(object):
] ]
@property @property
def index_names(self) -> List[Optional[Tuple]]: def index_names(self) -> List[Optional[Label]]:
"""Return the managed index names.""" """Return the managed index names."""
return self._index_names return self._index_names
@ -1115,7 +1116,7 @@ class InternalFrame(object):
return len(self._index_names) return len(self._index_names)
@property @property
def column_labels(self) -> List[Tuple]: def column_labels(self) -> List[Label]:
"""Return the managed column index.""" """Return the managed column index."""
return self._column_labels return self._column_labels
@ -1125,7 +1126,7 @@ class InternalFrame(object):
return len(self._column_label_names) return len(self._column_label_names)
@property @property
def column_label_names(self) -> List[Optional[Tuple]]: def column_label_names(self) -> List[Optional[Label]]:
"""Return names of the index levels.""" """Return names of the index levels."""
return self._column_label_names return self._column_label_names
@ -1197,10 +1198,10 @@ class InternalFrame(object):
pdf: pd.DataFrame, pdf: pd.DataFrame,
*, *,
index_columns: List[str], index_columns: List[str],
index_names: List[Tuple], index_names: List[Label],
data_columns: List[str], data_columns: List[str],
column_labels: List[Tuple], column_labels: List[Label],
column_label_names: List[Tuple], column_label_names: List[Label],
fields: List[InternalField] = None, fields: List[InternalField] = None,
) -> pd.DataFrame: ) -> pd.DataFrame:
""" """
@ -1335,9 +1336,9 @@ class InternalFrame(object):
self, self,
scols_or_pssers: Sequence[Union[Column, "Series"]], scols_or_pssers: Sequence[Union[Column, "Series"]],
*, *,
column_labels: Optional[List[Tuple]] = None, column_labels: Optional[List[Label]] = None,
data_fields: Optional[List[InternalField]] = None, data_fields: Optional[List[InternalField]] = None,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, column_label_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue,
keep_order: bool = True, keep_order: bool = True,
) -> "InternalFrame": ) -> "InternalFrame":
""" """
@ -1439,7 +1440,7 @@ class InternalFrame(object):
def with_new_spark_column( def with_new_spark_column(
self, self,
column_label: Tuple, column_label: Label,
scol: Column, scol: Column,
*, *,
field: Optional[InternalField] = None, field: Optional[InternalField] = None,
@ -1465,7 +1466,7 @@ class InternalFrame(object):
data_spark_columns, data_fields=data_fields, keep_order=keep_order data_spark_columns, data_fields=data_fields, keep_order=keep_order
) )
def select_column(self, column_label: Tuple) -> "InternalFrame": def select_column(self, column_label: Label) -> "InternalFrame":
""" """
Copy the immutable InternalFrame with the specified column. Copy the immutable InternalFrame with the specified column.
@ -1486,12 +1487,12 @@ class InternalFrame(object):
*, *,
spark_frame: Union[SparkDataFrame, _NoValueType] = _NoValue, spark_frame: Union[SparkDataFrame, _NoValueType] = _NoValue,
index_spark_columns: Union[List[Column], _NoValueType] = _NoValue, index_spark_columns: Union[List[Column], _NoValueType] = _NoValue,
index_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, index_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue,
index_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue, index_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue,
column_labels: Union[Optional[List[Tuple]], _NoValueType] = _NoValue, column_labels: Union[Optional[List[Label]], _NoValueType] = _NoValue,
data_spark_columns: Union[Optional[List[Column]], _NoValueType] = _NoValue, data_spark_columns: Union[Optional[List[Column]], _NoValueType] = _NoValue,
data_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue, data_fields: Union[Optional[List[InternalField]], _NoValueType] = _NoValue,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue, column_label_names: Union[Optional[List[Optional[Label]]], _NoValueType] = _NoValue,
) -> "InternalFrame": ) -> "InternalFrame":
""" """
Copy the immutable InternalFrame. Copy the immutable InternalFrame.
@ -1530,12 +1531,12 @@ class InternalFrame(object):
return InternalFrame( return InternalFrame(
spark_frame=cast(SparkDataFrame, spark_frame), spark_frame=cast(SparkDataFrame, spark_frame),
index_spark_columns=cast(List[Column], index_spark_columns), index_spark_columns=cast(List[Column], index_spark_columns),
index_names=cast(Optional[List[Optional[Tuple]]], index_names), index_names=cast(Optional[List[Optional[Label]]], index_names),
index_fields=cast(Optional[List[InternalField]], index_fields), index_fields=cast(Optional[List[InternalField]], index_fields),
column_labels=cast(Optional[List[Tuple]], column_labels), column_labels=cast(Optional[List[Label]], column_labels),
data_spark_columns=cast(Optional[List[Column]], data_spark_columns), data_spark_columns=cast(Optional[List[Column]], data_spark_columns),
data_fields=cast(Optional[List[InternalField]], data_fields), data_fields=cast(Optional[List[InternalField]], data_fields),
column_label_names=cast(Optional[List[Optional[Tuple]]], column_label_names), column_label_names=cast(Optional[List[Optional[Label]]], column_label_names),
) )
@staticmethod @staticmethod
@ -1548,16 +1549,16 @@ class InternalFrame(object):
index_names = [ index_names = [
name if name is None or isinstance(name, tuple) else (name,) for name in pdf.index.names name if name is None or isinstance(name, tuple) else (name,) for name in pdf.index.names
] ] # type: List[Optional[Label]]
columns = pdf.columns columns = pdf.columns
if isinstance(columns, pd.MultiIndex): if isinstance(columns, pd.MultiIndex):
column_labels = columns.tolist() column_labels = columns.tolist() # type: List[Label]
else: else:
column_labels = [(col,) for col in columns] column_labels = [(col,) for col in columns]
column_label_names = [ column_label_names = [
name if name is None or isinstance(name, tuple) else (name,) for name in columns.names name if name is None or isinstance(name, tuple) else (name,) for name in columns.names
] ] # type: List[Optional[Label]]
( (
pdf, pdf,
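
The derivation above is easy to reproduce with plain pandas: ordinary columns become one-element labels, while a `MultiIndex.tolist()` already yields tuples:

```python
import pandas as pd

pdf = pd.DataFrame({"a": [1], "b": [2]})
if isinstance(pdf.columns, pd.MultiIndex):
    column_labels = pdf.columns.tolist()  # already a list of tuples
else:
    column_labels = [(col,) for col in pdf.columns]

assert column_labels == [("a",), ("b",)]
```
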

View file

@ -24,6 +24,7 @@ import pyspark
from pyspark.ml.feature import VectorAssembler from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation from pyspark.ml.stat import Correlation
from pyspark.pandas._typing import Label
from pyspark.pandas.utils import column_labels_level from pyspark.pandas.utils import column_labels_level
if TYPE_CHECKING: if TYPE_CHECKING:
@ -62,7 +63,7 @@ def corr(psdf: "ps.DataFrame", method: str = "pearson") -> pd.DataFrame:
return pd.DataFrame(arr, columns=idx, index=idx) return pd.DataFrame(arr, columns=idx, index=idx)
def to_numeric_df(psdf: "ps.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Tuple]]: def to_numeric_df(psdf: "ps.DataFrame") -> Tuple[pyspark.sql.DataFrame, List[Label]]:
""" """
Takes a dataframe and turns it into a dataframe containing a single numerical Takes a dataframe and turns it into a dataframe containing a single numerical
vector of doubles. This dataframe has a single field called '_1'. vector of doubles. This dataframe has a single field called '_1'.

View file

@ -18,13 +18,14 @@
""" """
MLflow-related functions to load models and apply them to pandas-on-Spark dataframes. MLflow-related functions to load models and apply them to pandas-on-Spark dataframes.
""" """
from typing import List, Tuple, Union # noqa: F401 (SPARK-34943) from typing import List, Union # noqa: F401 (SPARK-34943)
from pyspark.sql.types import DataType from pyspark.sql.types import DataType
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from typing import Any from typing import Any
from pyspark.pandas._typing import Label # noqa: F401 (SPARK-34943)
from pyspark.pandas.utils import lazy_property, default_session from pyspark.pandas.utils import lazy_property, default_session
from pyspark.pandas.frame import DataFrame from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series, first_series from pyspark.pandas.series import Series, first_series
@ -99,7 +100,7 @@ class PythonModelWrapper(object):
# return_col = self._model_udf(s) # return_col = self._model_udf(s)
column_labels = [ column_labels = [
(col,) for col in data._internal.spark_frame.select(return_col).columns (col,) for col in data._internal.spark_frame.select(return_col).columns
] # type: List[Tuple] ] # type: List[Label]
internal = data._internal.copy( internal = data._internal.copy(
column_labels=column_labels, data_spark_columns=[return_col], data_fields=None column_labels=column_labels, data_spark_columns=[return_col], data_fields=None
) )

View file

@ -65,7 +65,7 @@ from pyspark.sql.types import (
) )
from pyspark import pandas as ps # noqa: F401 from pyspark import pandas as ps # noqa: F401
from pyspark.pandas._typing import Axis, Dtype from pyspark.pandas._typing import Axis, Dtype, Label, Name
from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.utils import ( from pyspark.pandas.utils import (
align_diff_frames, align_diff_frames,
@ -398,7 +398,7 @@ def read_csv(
if col not in column_labels: if col not in column_labels:
raise KeyError(col) raise KeyError(col)
index_spark_column_names = [column_labels[col] for col in index_col] index_spark_column_names = [column_labels[col] for col in index_col]
index_names = [(col,) for col in index_col] # type: List[Tuple] index_names = [(col,) for col in index_col] # type: List[Label]
column_labels = OrderedDict( column_labels = OrderedDict(
(label, col) for label, col in column_labels.items() if label not in index_col (label, col) for label, col in column_labels.items() if label not in index_col
) )
@ -1818,7 +1818,7 @@ def get_dummies(
prefix: Optional[Union[str, List[str], Dict[str, str]]] = None, prefix: Optional[Union[str, List[str], Dict[str, str]]] = None,
prefix_sep: str = "_", prefix_sep: str = "_",
dummy_na: bool = False, dummy_na: bool = False,
columns: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, columns: Optional[Union[Name, List[Name]]] = None,
sparse: bool = False, sparse: bool = False,
drop_first: bool = False, drop_first: bool = False,
dtype: Optional[Union[str, Dtype]] = None, dtype: Optional[Union[str, Dtype]] = None,
@ -2443,8 +2443,8 @@ def concat(
def melt( def melt(
frame: DataFrame, frame: DataFrame,
id_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, id_vars: Optional[Union[Name, List[Name]]] = None,
value_vars: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, value_vars: Optional[Union[Name, List[Name]]] = None,
var_name: Optional[Union[str, List[str]]] = None, var_name: Optional[Union[str, List[str]]] = None,
value_name: str = "value", value_name: str = "value",
) -> DataFrame: ) -> DataFrame:
@ -2616,9 +2616,9 @@ def merge(
obj: DataFrame, obj: DataFrame,
right: DataFrame, right: DataFrame,
how: str = "inner", how: str = "inner",
on: Union[Any, List[Any], Tuple, List[Tuple]] = None, on: Optional[Union[Name, List[Name]]] = None,
left_on: Union[Any, List[Any], Tuple, List[Tuple]] = None, left_on: Optional[Union[Name, List[Name]]] = None,
right_on: Union[Any, List[Any], Tuple, List[Tuple]] = None, right_on: Optional[Union[Name, List[Name]]] = None,
left_index: bool = False, left_index: bool = False,
right_index: bool = False, right_index: bool = False,
suffixes: Tuple[str, str] = ("_x", "_y"), suffixes: Tuple[str, str] = ("_x", "_y"),
@ -2919,7 +2919,7 @@ def read_orc(
def _get_index_map( def _get_index_map(
sdf: SparkDataFrame, index_col: Optional[Union[str, List[str]]] = None sdf: SparkDataFrame, index_col: Optional[Union[str, List[str]]] = None
) -> Tuple[Optional[List[Column]], Optional[List[Tuple]]]: ) -> Tuple[Optional[List[Column]], Optional[List[Label]]]:
if index_col is not None: if index_col is not None:
if isinstance(index_col, str): if isinstance(index_col, str):
index_col = [index_col] index_col = [index_col]
@ -2930,7 +2930,7 @@ def _get_index_map(
index_spark_columns = [ index_spark_columns = [
scol_for(sdf, col) for col in index_col scol_for(sdf, col) for col in index_col
] # type: Optional[List[Column]] ] # type: Optional[List[Column]]
index_names = [(col,) for col in index_col] # type: Optional[List[Tuple]] index_names = [(col,) for col in index_col] # type: Optional[List[Label]]
else: else:
index_spark_columns = None index_spark_columns = None
index_names = None index_names = None

View file

@ -67,7 +67,7 @@ from pyspark.sql.types import (
from pyspark.sql.window import Window from pyspark.sql.window import Window
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Axis, Dtype, Scalar, T from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
from pyspark.pandas.accessors import PandasOnSparkSeriesMethods from pyspark.pandas.accessors import PandasOnSparkSeriesMethods
from pyspark.pandas.categorical import CategoricalAccessor from pyspark.pandas.categorical import CategoricalAccessor
from pyspark.pandas.config import get_option from pyspark.pandas.config import get_option
@ -410,7 +410,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
assert not fastpath assert not fastpath
self._anchor = data # type: DataFrame self._anchor = data # type: DataFrame
self._col_label = index # type: Tuple self._col_label = index # type: Label
else: else:
if isinstance(data, pd.Series): if isinstance(data, pd.Series):
assert index is None assert index is None
@ -441,7 +441,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
return self._psdf._internal.select_column(self._column_label) return self._psdf._internal.select_column(self._column_label)
@property @property
def _column_label(self) -> Optional[Tuple]: def _column_label(self) -> Optional[Label]:
return self._col_label return self._col_label
def _update_anchor(self, psdf: DataFrame) -> None: def _update_anchor(self, psdf: DataFrame) -> None:
@ -1042,7 +1042,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
return (len(self),) return (len(self),)
@property @property
def name(self) -> Union[Any, Tuple]: def name(self) -> Name:
"""Return name of the Series.""" """Return name of the Series."""
name = self._column_label name = self._column_label
if name is not None and len(name) == 1: if name is not None and len(name) == 1:
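
`Series.name` is typed as `Name` because a one-element column label `(x,)` is surfaced as the scalar `x`, while multi-level labels come back as tuples. A sketch (assumes a running Spark session):

```python
import pyspark.pandas as ps

s = ps.Series([1, 2, 3], name="a")
assert s.name == "a"  # stored as the Label ("a",), exposed as the scalar Name "a"
```
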
@ -1051,12 +1051,12 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
return name return name
@name.setter @name.setter
def name(self, name: Union[Any, Tuple]) -> None: def name(self, name: Name) -> None:
self.rename(name, inplace=True) self.rename(name, inplace=True)
# TODO: Functionality and documentation should be matched. Currently, changing index labels # TODO: Functionality and documentation should be matched. Currently, changing index labels
# taking dictionary and function to change index are not supported. # taking dictionary and function to change index are not supported.
def rename(self, index: Optional[Union[Any, Tuple]] = None, **kwargs: Any) -> "Series": def rename(self, index: Optional[Name] = None, **kwargs: Any) -> "Series":
""" """
Alter Series name. Alter Series name.
@ -1227,9 +1227,9 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
def reset_index( def reset_index(
self, self,
level: Optional[Union[int, Any, Tuple, Sequence[Union[int, Any, Tuple]]]] = None, level: Optional[Union[int, Name, Sequence[Union[int, Name]]]] = None,
drop: bool = False, drop: bool = False,
name: Optional[Union[Any, Tuple]] = None, name: Optional[Name] = None,
inplace: bool = False, inplace: bool = False,
) -> Optional[Union["Series", DataFrame]]: ) -> Optional[Union["Series", DataFrame]]:
""" """
@ -1324,7 +1324,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
else: else:
return psdf return psdf
def to_frame(self, name: Optional[Union[Any, Tuple]] = None) -> DataFrame: def to_frame(self, name: Optional[Name] = None) -> DataFrame:
""" """
Convert Series to DataFrame. Convert Series to DataFrame.
@ -1491,13 +1491,13 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
def to_latex( def to_latex(
self, self,
buf: Optional[IO[str]] = None, buf: Optional[IO[str]] = None,
columns: Optional[List[Union[Any, Tuple]]] = None, columns: Optional[List[Name]] = None,
col_space: Optional[int] = None, col_space: Optional[int] = None,
header: bool = True, header: bool = True,
index: bool = True, index: bool = True,
na_rep: str = "NaN", na_rep: str = "NaN",
formatters: Optional[ formatters: Optional[
Union[List[Callable[[Any], str]], Dict[Union[Any, Tuple], Callable[[Any], str]]] Union[List[Callable[[Any], str]], Dict[Name, Callable[[Any], str]]]
] = None, ] = None,
float_format: Optional[Callable[[float], str]] = None, float_format: Optional[Callable[[float], str]] = None,
sparsify: Optional[bool] = None, sparsify: Optional[bool] = None,
@ -2062,8 +2062,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
def drop( def drop(
self, self,
labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, labels: Optional[Union[Name, List[Name]]] = None,
index: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, index: Optional[Union[Name, List[Name]]] = None,
level: Optional[int] = None, level: Optional[int] = None,
) -> "Series": ) -> "Series":
""" """
@ -2177,8 +2177,8 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
def _drop( def _drop(
self, self,
labels: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, labels: Optional[Union[Name, List[Name]]] = None,
index: Optional[Union[Any, Tuple, List[Union[Any, Tuple]]]] = None, index: Optional[Union[Name, List[Name]]] = None,
level: Optional[int] = None, level: Optional[int] = None,
) -> DataFrame: ) -> DataFrame:
if labels is not None: if labels is not None:
@ -2193,7 +2193,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
raise ValueError("'level' should be less than the number of indexes") raise ValueError("'level' should be less than the number of indexes")
if is_name_like_tuple(index): # type: ignore if is_name_like_tuple(index): # type: ignore
index_list = [cast(Tuple, index)] index_list = [cast(Label, index)]
elif is_name_like_value(index): elif is_name_like_value(index):
index_list = [(index,)] index_list = [(index,)]
elif all(is_name_like_value(idxes, allow_tuple=False) for idxes in index): elif all(is_name_like_value(idxes, allow_tuple=False) for idxes in index):
@ -2205,7 +2205,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
"that contain index names" "that contain index names"
) )
else: else:
index_list = cast(List[Tuple], index) index_list = cast(List[Label], index)
drop_index_scols = [] drop_index_scols = []
for idxes in index_list: for idxes in index_list:
@ -2602,7 +2602,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
return first_series(psdf) return first_series(psdf)
def swaplevel( def swaplevel(
self, i: Union[int, Any, Tuple] = -2, j: Union[int, Any, Tuple] = -1, copy: bool = True self, i: Union[int, Name] = -2, j: Union[int, Name] = -1, copy: bool = True
) -> "Series": ) -> "Series":
""" """
Swap levels i and j in a MultiIndex. Swap levels i and j in a MultiIndex.
@ -3903,7 +3903,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
else: else:
return tuple(values) return tuple(values)
def pop(self, item: Union[Any, Tuple]) -> Union["Series", Scalar]: def pop(self, item: Name) -> Union["Series", Scalar]:
""" """
Return item and drop from series. Return item and drop from series.
@ -4671,7 +4671,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
""" """
return self.where(cast(Series, ~cond), other) return self.where(cast(Series, ~cond), other)
def xs(self, key: Union[Any, Tuple], level: Optional[int] = None) -> "Series": def xs(self, key: Name, level: Optional[int] = None) -> "Series":
""" """
Return cross-section from the Series. Return cross-section from the Series.
@ -5278,7 +5278,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
""" """
return self.head(2)._to_internal_pandas().item() return self.head(2)._to_internal_pandas().item()
def iteritems(self) -> Iterable[Tuple[Union[Any, Tuple], Any]]: def iteritems(self) -> Iterable[Tuple[Name, Any]]:
""" """
Lazily iterate over (index, value) tuples. Lazily iterate over (index, value) tuples.
@ -5311,7 +5311,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
internal_index_columns = self._internal.index_spark_column_names internal_index_columns = self._internal.index_spark_column_names
internal_data_column = self._internal.data_spark_column_names[0] internal_data_column = self._internal.data_spark_column_names[0]
def extract_kv_from_spark_row(row: Row) -> Tuple[Union[Any, Tuple], Any]: def extract_kv_from_spark_row(row: Row) -> Tuple[Name, Any]:
k = ( k = (
row[internal_index_columns[0]] row[internal_index_columns[0]]
if len(internal_index_columns) == 1 if len(internal_index_columns) == 1
@ -5325,11 +5325,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
): ):
yield k, v yield k, v
def items(self) -> Iterable[Tuple[Union[Any, Tuple], Any]]: def items(self) -> Iterable[Tuple[Name, Any]]:
"""This is an alias of ``iteritems``.""" """This is an alias of ``iteritems``."""
return self.iteritems() return self.iteritems()
def droplevel(self, level: Union[int, Any, Tuple, List[Union[int, Any, Tuple]]]) -> "Series": def droplevel(self, level: Union[int, Name, List[Union[int, Name]]]) -> "Series":
""" """
Return Series with requested index level(s) removed. Return Series with requested index level(s) removed.
@ -6213,7 +6213,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
# Override the `groupby` to specify the actual return type annotation. # Override the `groupby` to specify the actual return type annotation.
def groupby( def groupby(
self, self,
by: Union[Any, Tuple, "Series", List[Union[Any, Tuple, "Series"]]], by: Union[Name, "Series", List[Union[Name, "Series"]]],
axis: Axis = 0, axis: Axis = 0,
as_index: bool = True, as_index: bool = True,
dropna: bool = True, dropna: bool = True,
@ -6225,7 +6225,7 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
groupby.__doc__ = Frame.groupby.__doc__ groupby.__doc__ = Frame.groupby.__doc__
def _build_groupby( def _build_groupby(
self, by: List[Union["Series", Tuple]], as_index: bool, dropna: bool self, by: List[Union["Series", Label]], as_index: bool, dropna: bool
) -> "SeriesGroupBy": ) -> "SeriesGroupBy":
from pyspark.pandas.groupby import SeriesGroupBy from pyspark.pandas.groupby import SeriesGroupBy

View file

@ -46,7 +46,7 @@ from pandas.api.types import is_list_like
# For running doctests and reference resolution in PyCharm. # For running doctests and reference resolution in PyCharm.
from pyspark import pandas as ps # noqa: F401 from pyspark import pandas as ps # noqa: F401
from pyspark.pandas._typing import Axis, DataFrameOrSeries from pyspark.pandas._typing import Axis, Label, Name, DataFrameOrSeries
from pyspark.pandas.spark import functions as SF from pyspark.pandas.spark import functions as SF
from pyspark.pandas.typedef.typehints import as_spark_type from pyspark.pandas.typedef.typehints import as_spark_type
@ -279,7 +279,7 @@ def combine_frames(
level = max(this_internal.column_labels_level, that_internal.column_labels_level) level = max(this_internal.column_labels_level, that_internal.column_labels_level)
def fill_label(label: Optional[Tuple]) -> List: def fill_label(label: Optional[Label]) -> List:
if label is None: if label is None:
return ([""] * (level - 1)) + [None] return ([""] * (level - 1)) + [None]
else: else:
@ -289,7 +289,7 @@ def combine_frames(
tuple(["this"] + fill_label(label)) for label in this_internal.column_labels tuple(["this"] + fill_label(label)) for label in this_internal.column_labels
] + [tuple(["that"] + fill_label(label)) for label in that_internal.column_labels] ] + [tuple(["that"] + fill_label(label)) for label in that_internal.column_labels]
column_label_names = ( column_label_names = (
cast(List[Optional[Tuple]], [None]) * (1 + level - this_internal.column_labels_level) cast(List[Optional[Label]], [None]) * (1 + level - this_internal.column_labels_level)
) + this_internal.column_label_names ) + this_internal.column_label_names
return DataFrame( return DataFrame(
InternalFrame( InternalFrame(
@ -309,7 +309,7 @@ def combine_frames(
def align_diff_frames( def align_diff_frames(
resolve_func: Callable[ resolve_func: Callable[
["DataFrame", List[Tuple], List[Tuple]], Iterator[Tuple["Series", Tuple]] ["DataFrame", List[Label], List[Label]], Iterator[Tuple["Series", Label]]
], ],
this: "DataFrame", this: "DataFrame",
that: "DataFrame", that: "DataFrame",
@ -385,11 +385,11 @@ def align_diff_frames(
# 2. Apply the given function to transform the columns in a batch and keep the new columns. # 2. Apply the given function to transform the columns in a batch and keep the new columns.
combined_column_labels = combined._internal.column_labels combined_column_labels = combined._internal.column_labels
that_columns_to_apply = [] # type: List[Tuple] that_columns_to_apply = [] # type: List[Label]
this_columns_to_apply = [] # type: List[Tuple] this_columns_to_apply = [] # type: List[Label]
additional_that_columns = [] # type: List[Tuple] additional_that_columns = [] # type: List[Label]
columns_to_keep = [] # type: List[Union[Series, Column]] columns_to_keep = [] # type: List[Union[Series, Column]]
column_labels_to_keep = [] # type: List[Tuple] column_labels_to_keep = [] # type: List[Label]
for combined_label in combined_column_labels: for combined_label in combined_column_labels:
for common_label in common_column_labels: for common_label in common_column_labels:
@ -424,7 +424,7 @@ def align_diff_frames(
*resolve_func(combined, this_columns_to_apply, that_columns_to_apply) *resolve_func(combined, this_columns_to_apply, that_columns_to_apply)
) )
columns_applied = list(psser_set) # type: List[Union[Series, Column]] columns_applied = list(psser_set) # type: List[Union[Series, Column]]
column_labels_applied = list(column_labels_set) # type: List[Tuple] column_labels_applied = list(column_labels_set) # type: List[Label]
else: else:
columns_applied = [] columns_applied = []
column_labels_applied = [] column_labels_applied = []
@ -592,7 +592,7 @@ def scol_for(sdf: SparkDataFrame, column_name: str) -> Column:
return sdf["`{}`".format(column_name)] return sdf["`{}`".format(column_name)]
def column_labels_level(column_labels: List[Tuple]) -> int: def column_labels_level(column_labels: List[Label]) -> int:
"""Return the level of the column index.""" """Return the level of the column index."""
if len(column_labels) == 0: if len(column_labels) == 0:
return 1 return 1
@ -602,7 +602,7 @@ def column_labels_level(column_labels: List[Tuple]) -> int:
return list(levels)[0] return list(levels)[0]
def name_like_string(name: Optional[Union[Any, Tuple]]) -> str: def name_like_string(name: Optional[Name]) -> str:
""" """
Return the name-like strings from str or tuple of str Return the name-like strings from str or tuple of str
@ -621,12 +621,12 @@ def name_like_string(name: Optional[Union[Any, Tuple]]) -> str:
'(a, b, c)' '(a, b, c)'
""" """
if name is None: if name is None:
name = ("__none__",) label = ("__none__",) # type: Label
elif is_list_like(name): elif is_list_like(name):
name = tuple([str(n) for n in name]) label = tuple([str(n) for n in name])
else: else:
name = (str(name),) label = (str(name),)
return ("(%s)" % ", ".join(name)) if len(name) > 1 else name[0] return ("(%s)" % ", ".join(label)) if len(label) > 1 else label[0]
def is_name_like_tuple(value: Any, allow_none: bool = True, check_type: bool = False) -> bool: def is_name_like_tuple(value: Any, allow_none: bool = True, check_type: bool = False) -> bool:
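
The rewrite of `name_like_string` leaves the incoming `name` untouched and builds a separate `label` variable, which is what makes the `# type: Label` annotation possible. Its observable behavior, per the docstring examples kept above, is unchanged:

```python
from pyspark.pandas.utils import name_like_string

assert name_like_string(None) == "__none__"
assert name_like_string("a") == "a"
assert name_like_string(("a", "b", "c")) == "(a, b, c)"
```
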
@ -760,15 +760,13 @@ def verify_temp_column_name(df: SparkDataFrame, column_name_or_label: str) -> st
@overload @overload
def verify_temp_column_name( def verify_temp_column_name(df: "DataFrame", column_name_or_label: Name) -> Label:
df: "DataFrame", column_name_or_label: Union[Any, Tuple]
) -> Union[Any, Tuple]:
... ...
def verify_temp_column_name( def verify_temp_column_name(
df: Union["DataFrame", SparkDataFrame], column_name_or_label: Union[Any, Tuple] df: Union["DataFrame", SparkDataFrame], column_name_or_label: Union[str, Name]
) -> Union[Any, Tuple]: ) -> Union[str, Label]:
""" """
Verify that the given column name does not exist in the given pandas-on-Spark or Verify that the given column name does not exist in the given pandas-on-Spark or
Spark DataFrame. Spark DataFrame.