[SPARK-36711][PYTHON] Support multi-index in new syntax

### What changes were proposed in this pull request?
Support multi-index in the new type-hint syntax for specifying index data types.

### Why are the changes needed?
The new syntax for specifying index data types currently accepts only a single index; this change extends it to cover a MultiIndex as well.

https://issues.apache.org/jira/browse/SPARK-36707

### Does this PR introduce _any_ user-facing change?
After this PR, users can write:

``` python
>>> ps.DataFrame[[int, int],[int, int]]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
>>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
>>> pdf = pd.DataFrame([[1,2,3],[2,3,4],[4,5,6]], index=idx, columns=["a", "b", "c"])
>>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> ps.DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

```
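For example, these hints can annotate a function passed to `pandas_on_spark.apply_batch`, so the result keeps the MultiIndex names and dtypes. A minimal sketch mirroring the test added in this PR (the function name `identity` and the data values are illustrative):

``` python
>>> import pandas as pd
>>> import pyspark.pandas as ps
>>> arrays = [[1, 1, 2], ["red", "blue", "red"]]
>>> idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
>>> psdf = ps.from_pandas(pd.DataFrame({"a": [1, 2, 3]}, index=idx))
>>> def identity(x) -> ps.DataFrame[[("number", int), ("color", str)], [("a", int)]]:
...     return x
>>> psdf.pandas_on_spark.apply_batch(identity)  # index names and dtypes are retained
```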

### How was this patch tested?
Existing tests, plus new test cases for multi-index type hints added in this PR.

Closes #34176 from dchvn/SPARK-36711.

Authored-by: dchvn nguyen <dgd_contributor@viettel.com.vn>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
dchvn nguyen 2021-10-05 12:45:16 +09:00 committed by Hyukjin Kwon
parent fa1805db48
commit d6786e036d
5 changed files with 252 additions and 121 deletions

python/pyspark/pandas/accessors.py

@@ -34,7 +34,7 @@ from pyspark.pandas.internal import (
     InternalFrame,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
 from pyspark.pandas.utils import (
@@ -384,8 +384,8 @@ class PandasOnSparkFrameMethods(object):
                     "The given function should specify a frame as its type "
                     "hints; however, the return type was %s." % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             return_schema = cast(DataFrameType, return_type).spark_type

             output_func = GroupBy._make_pandas_df_builder_func(
@@ -397,12 +397,19 @@ class PandasOnSparkFrameMethods(object):
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_fields = None

             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]

             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
@@ -680,17 +687,19 @@ class PandasOnSparkFrameMethods(object):
                 )
                 return first_series(DataFrame(internal))
             else:
-                index_field = cast(DataFrameType, return_type).index_field
-                index_field = (
-                    index_field.normalize_spark_type() if index_field is not None else None
+                index_fields = cast(DataFrameType, return_type).index_fields
+                index_fields = (
+                    [index_field.normalize_spark_type() for index_field in index_fields]
+                    if index_fields is not None
+                    else None
                 )
                 data_fields = [
                     field.normalize_spark_type()
                     for field in cast(DataFrameType, return_type).data_fields
                 ]
-                normalized_fields = ([index_field] if index_field is not None else []) + data_fields
+                normalized_fields = (index_fields if index_fields is not None else []) + data_fields
                 return_schema = StructType([field.struct_field for field in normalized_fields])
-                should_retain_index = index_field is not None
+                should_retain_index = index_fields is not None

                 self_applied = DataFrame(self._psdf._internal.resolved_copy)
@@ -711,12 +720,21 @@ class PandasOnSparkFrameMethods(object):
                 index_spark_columns = None
                 index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-                index_fields = None

                 if should_retain_index:
-                    index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                    index_fields = [index_field]
-                    if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                        index_names = [(index_field.struct_field.name,)]
+                    index_spark_columns = [
+                        scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                    ]
+
+                    if not any(
+                        [
+                            SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                            for index_field in index_fields
+                        ]
+                    ):
+                        index_names = [
+                            (index_field.struct_field.name,) for index_field in index_fields
+                        ]

                 internal = InternalFrame(
                     spark_frame=sdf,
                     index_names=index_names,

python/pyspark/pandas/frame.py

@@ -114,6 +114,7 @@ from pyspark.pandas.internal import (
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_INDEX_NAME,
     SPARK_DEFAULT_SERIES_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
 from pyspark.pandas.ml import corr
@@ -2511,7 +2512,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return_type = infer_return_type(func)
         require_index_axis = isinstance(return_type, SeriesType)
         require_column_axis = isinstance(return_type, DataFrameType)
-        index_field = None
+        index_fields = None

         if require_index_axis:
             if axis != 0:
@@ -2536,8 +2537,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                     "hints when axis is 1 or 'column'; however, the return type "
                     "was %s" % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             data_fields = cast(DataFrameType, return_type).data_fields
             return_schema = cast(DataFrameType, return_type).spark_type
         else:
@@ -2565,12 +2566,19 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_fields = None

             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]

             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,

python/pyspark/pandas/groupby.py

@@ -76,7 +76,7 @@ from pyspark.pandas.internal import (
     NATURAL_ORDER_COLUMN_NAME,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.groupby import (
     MissingPandasLikeDataFrameGroupBy,
@@ -1252,9 +1252,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             if isinstance(return_type, DataFrameType):
                 data_fields = cast(DataFrameType, return_type).data_fields
                 return_schema = cast(DataFrameType, return_type).spark_type
-                index_field = cast(DataFrameType, return_type).index_field
-                should_retain_index = index_field is not None
-                index_fields = [index_field]
+                index_fields = cast(DataFrameType, return_type).index_fields
+                should_retain_index = index_fields is not None
                 psdf_from_pandas = None
             else:
                 should_return_series = True
@@ -1329,10 +1328,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
                 )
             else:
                 index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-                index_field = index_fields[0]
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]

                 internal = InternalFrame(
                     spark_frame=sdf,
                     index_names=index_names,

python/pyspark/pandas/tests/test_dataframe.py

@@ -4678,6 +4678,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         actual.columns = ["a", "b"]
         self.assert_eq(actual, pdf)

+        arrays = [[1, 2, 3, 4, 5, 6, 7, 8, 9], ["a", "b", "c", "d", "e", "f", "g", "h", "i"]]
+        idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
+            index=idx,
+        )
+        psdf = ps.from_pandas(pdf)
+
+        def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify4)
+        actual.index.names = ["number", "color"]
+        actual.columns = ["a", "b"]
+        self.assert_eq(actual, pdf)
+
+        def identify5(
+            x,
+        ) -> ps.DataFrame[
+            [("number", int), ("color", str)], [("a", int), ("b", List[int])]  # noqa: F405
+        ]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify5)
+        self.assert_eq(actual, pdf)
+
     def test_transform_batch(self):
         pdf = pd.DataFrame(
             {
python/pyspark/pandas/typedef/typehints.py

@@ -94,12 +94,12 @@ class SeriesType(Generic[T]):
 class DataFrameType(object):
     def __init__(
         self,
-        index_field: Optional["InternalField"],
+        index_fields: Optional[List["InternalField"]],
         data_fields: List["InternalField"],
     ):
-        self.index_field = index_field
+        self.index_fields = index_fields
         self.data_fields = data_fields
-        self.fields = [index_field] + data_fields if index_field is not None else data_fields
+        self.fields = index_fields + data_fields if isinstance(index_fields, List) else data_fields

     @property
     def dtypes(self) -> List[Dtype]:
@@ -514,8 +514,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]

     >>> def func() -> ps.DataFrame[pdf.index.dtype, pdf.dtypes]:
     ...     pass
@@ -524,8 +524,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]

     >>> def func() -> ps.DataFrame[
     ...     ("index", CategoricalDtype(categories=[3, 4, 5], ordered=False)),
@@ -536,8 +536,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [CategoricalDtype(categories=[3, 4, 5], ordered=False), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<index:bigint,id:bigint,A:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=category,struct_field=StructField(index,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=category,struct_field=StructField(index,LongType,true))]

     >>> def func() -> ps.DataFrame[
     ...     (pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]:
@@ -547,13 +547,13 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,a:bigint,b:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]
     """
     # We should re-import to make sure the class 'SeriesType' is not treated as a class
     # within this module locally. See Series.__class_getitem__ which imports this class
     # canonically.
-    from pyspark.pandas.internal import InternalField, SPARK_DEFAULT_INDEX_NAME
+    from pyspark.pandas.internal import InternalField, SPARK_INDEX_NAME_FORMAT
     from pyspark.pandas.typedef import SeriesType, NameTypeHolder, IndexNameTypeHolder
     from pyspark.pandas.utils import name_like_string
@@ -595,20 +595,26 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
         data_parameters = [p for p in parameters if p not in index_parameters]
         assert len(data_parameters) > 0, "Type hints for data must not be empty."

-        if len(index_parameters) == 1:
-            index_name = index_parameters[0].name
-            index_dtype, index_spark_type = pandas_on_spark_type(index_parameters[0].tpe)
-            index_field = InternalField(
-                dtype=index_dtype,
-                struct_field=types.StructField(
-                    name=index_name if index_name is not None else SPARK_DEFAULT_INDEX_NAME,
-                    dataType=index_spark_type,
-                ),
-            )
+        index_fields = []
+        if len(index_parameters) >= 1:
+            for level, index_parameter in enumerate(index_parameters):
+                index_name = index_parameter.name
+                index_dtype, index_spark_type = pandas_on_spark_type(index_parameter.tpe)
+                index_fields.append(
+                    InternalField(
+                        dtype=index_dtype,
+                        struct_field=types.StructField(
+                            name=index_name
+                            if index_name is not None
+                            else SPARK_INDEX_NAME_FORMAT(level),
+                            dataType=index_spark_type,
+                        ),
+                    )
+                )
         else:
             assert len(index_parameters) == 0
             # No type hint for index.
-            index_field = None
+            index_fields = None

         data_dtypes, data_spark_types = zip(
             *(
@@ -636,7 +642,7 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
             )
         )

-        return DataFrameType(index_field=index_field, data_fields=data_fields)
+        return DataFrameType(index_fields=index_fields, data_fields=data_fields)

     tpes = pandas_on_spark_type(tpe)
     if tpes is None:
@@ -684,10 +690,10 @@ def create_tuple_for_frame_type(params: Any) -> object:
     Typing data columns only:

-    >>> ps.DataFrame[float, float]
-    typing.Tuple[float, float]
-    >>> ps.DataFrame[pdf.dtypes]
-    typing.Tuple[numpy.int64]
+    >>> ps.DataFrame[float, float]  # doctest: +ELLIPSIS
+    typing.Tuple[...NameType, ...NameType]
+    >>> ps.DataFrame[pdf.dtypes]  # doctest: +ELLIPSIS
+    typing.Tuple[...NameType]
     >>> ps.DataFrame["id": int, "A": int]  # doctest: +ELLIPSIS
     typing.Tuple[...NameType, ...NameType]
     >>> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]  # doctest: +ELLIPSIS
@@ -696,48 +702,42 @@ def create_tuple_for_frame_type(params: Any) -> object:
     Typing data columns with an index:

     >>> ps.DataFrame[int, [int, int]]  # doctest: +ELLIPSIS
-    typing.Tuple[...IndexNameType, int, int]
+    typing.Tuple[...IndexNameType, ...NameType, ...NameType]
     >>> ps.DataFrame[pdf.index.dtype, pdf.dtypes]  # doctest: +ELLIPSIS
-    typing.Tuple[...IndexNameType, numpy.int64]
+    typing.Tuple[...IndexNameType, ...NameType]
     >>> ps.DataFrame[("index", int), [("id", int), ("A", int)]]  # doctest: +ELLIPSIS
     typing.Tuple[...IndexNameType, ...NameType, ...NameType]
     >>> ps.DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]
     ... # doctest: +ELLIPSIS
     typing.Tuple[...IndexNameType, ...NameType]

+    Typing data columns with a multi-index:
+
+    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+    >>> pdf = pd.DataFrame({'a': range(3)}, index=idx)
+    >>> ps.DataFrame[[int, int], [int, int]]  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType]
+    >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...NameType]
+    >>> ps.DataFrame[[("index-1", int), ("index-2", int)], [("id", int), ("A", int)]]
+    ... # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType]
+    >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]
+    ... # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...NameType]
     """
-    return Tuple[extract_types(params)]
+    return Tuple[_extract_types(params)]


 # TODO(SPARK-36708): numpy.typing (numpy 1.21+) support for nested types.
-def extract_types(params: Any) -> Tuple:
+def _extract_types(params: Any) -> Tuple:
     origin = params
-    if isinstance(params, zip):
-        # Example:
-        #   DataFrame[zip(pdf.columns, pdf.dtypes)]
-        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
-
-    if isinstance(params, Iterable):
-        params = tuple(params)
-    else:
-        params = (params,)
+    params = _to_tuple_of_params(params)

-    if all(
-        isinstance(param, slice)
-        and param.start is not None
-        and param.step is None
-        and param.stop is not None
-        for param in params
-    ):
+    if _is_named_params(params):
         # Example:
         #   DataFrame["id": int, "A": int]
-        new_params = []
-        for param in params:
-            new_param = type("NameType", (NameTypeHolder,), {})  # type: Type[NameTypeHolder]
-            new_param.name = param.start
-            # When the given argument is a numpy's dtype instance.
-            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
-            new_params.append(new_param)
+        new_params = _address_named_type_hoders(params, is_index=False)
         return tuple(new_params)
     elif len(params) == 2 and isinstance(params[1], (zip, list, pd.Series)):
         # Example:
@@ -745,49 +745,117 @@ def extract_types(params: Any) -> Tuple:
         #   DataFrame[pdf.index.dtype, pdf.dtypes]
         #   DataFrame[("index", int), [("id", int), ("A", int)]]
         #   DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]
+        #
+        #   DataFrame[[int, int], [int, int]]
+        #   DataFrame[pdf.index.dtypes, pdf.dtypes]
+        #   DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]]
+        #   DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]

-        index_param = params[0]
-        index_type = type(
-            "IndexNameType", (IndexNameTypeHolder,), {}
-        )  # type: Type[IndexNameTypeHolder]
-        if isinstance(index_param, tuple):
-            if len(index_param) != 2:
-                raise TypeError(
-                    "Type hints for index should be specified as "
-                    "DataFrame[('name', type), ...]; however, got %s" % index_param
-                )
-            name, tpe = index_param
-        else:
-            name, tpe = None, index_param
+        index_params = params[0]

-        index_type.name = name
-        if isinstance(tpe, ExtensionDtype):
-            index_type.tpe = tpe
+        if isinstance(index_params, tuple) and len(index_params) == 2:
+            index_params = tuple([slice(*index_params)])
+
+        index_params = _convert_tuples_to_zip(index_params)
+        index_params = _to_tuple_of_params(index_params)
+
+        if _is_named_params(index_params):
+            # Example:
+            #   DataFrame[[("id", int), ("A", int)], [int, int]]
+            new_index_params = _address_named_type_hoders(index_params, is_index=True)
+            index_types = tuple(new_index_params)
         else:
-            index_type.tpe = tpe.type if isinstance(tpe, np.dtype) else tpe
+            # Examples:
+            #   DataFrame[[float, float], [int, int]]
+            #   DataFrame[pdf.dtypes, [int, int]]
+            index_types = _address_unnamed_type_holders(index_params, origin, is_index=True)

         data_types = params[1]
-        if (
-            isinstance(data_types, list)
-            and len(data_types) >= 1
-            and isinstance(data_types[0], tuple)
-        ):
-            # Example:
-            #   DataFrame[("index", int), [("id", int), ("A", int)]]
-            data_types = zip((name for name, _ in data_types), (tpe for _, tpe in data_types))
-        return (index_type,) + extract_types(data_types)
-    elif all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
+        data_types = _convert_tuples_to_zip(data_types)
+
+        return index_types + _extract_types(data_types)
+    else:
         # Examples:
         #   DataFrame[float, float]
         #   DataFrame[pdf.dtypes]
+        return _address_unnamed_type_holders(params, origin, is_index=False)
+
+
+def _is_named_params(params: Any) -> Any:
+    return all(
+        isinstance(param, slice) and param.step is None and param.stop is not None
+        for param in params
+    )
+
+
+def _address_named_type_hoders(params: Any, is_index: bool) -> Any:
+    # Example:
+    #   params = (slice("id", int, None), slice("A", int, None))
+    new_params = []
+    for param in params:
+        new_param = (
+            type("IndexNameType", (IndexNameTypeHolder,), {})
+            if is_index
+            else type("NameType", (NameTypeHolder,), {})
+        )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
+        new_param.name = param.start
+        if isinstance(param.stop, ExtensionDtype):
+            new_param.tpe = param.stop
+        else:
+            # When the given argument is a numpy's dtype instance.
+            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
+        new_params.append(new_param)
+    return new_params
+
+
+def _to_tuple_of_params(params: Any) -> Any:
+    """
+    >>> _to_tuple_of_params(int)
+    (<class 'int'>,)
+    >>> _to_tuple_of_params([int, int, int])
+    (<class 'int'>, <class 'int'>, <class 'int'>)
+    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+    >>> pdf = pd.DataFrame([[1, 2], [2, 3], [4, 5]], index=idx, columns=["a", "b"])
+    >>> _to_tuple_of_params(zip(pdf.columns, pdf.dtypes))
+    (slice('a', dtype('int64'), None), slice('b', dtype('int64'), None))
+    >>> _to_tuple_of_params(zip(pdf.index.names, pdf.index.dtypes))
+    (slice('number', dtype('int64'), None), slice('color', dtype('O'), None))
+    """
+    if isinstance(params, zip):
+        # Example:
+        #   DataFrame[zip(pdf.columns, pdf.dtypes)]
+        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
+
+    if isinstance(params, Iterable):
+        params = tuple(params)
+    else:
+        params = (params,)
+    return params
+
+
+def _convert_tuples_to_zip(params: Any) -> Any:
+    if isinstance(params, list) and len(params) >= 1 and isinstance(params[0], tuple):
+        return zip((name for name, _ in params), (tpe for _, tpe in params))
+    return params
+
+
+def _address_unnamed_type_holders(params: Any, origin: Any, is_index: bool) -> Any:
+    if all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
         new_types = []
         for param in params:
+            new_type = (
+                type("IndexNameType", (IndexNameTypeHolder,), {})
+                if is_index
+                else type("NameType", (NameTypeHolder,), {})
+            )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
             if isinstance(param, ExtensionDtype):
-                new_type = type("NameType", (NameTypeHolder,), {})  # type: Type[NameTypeHolder]
                 new_type.tpe = param
-                new_types.append(new_type)
             else:
-                new_types.append(param.type if isinstance(param, np.dtype) else param)
+                new_type.tpe = param.type if isinstance(param, np.dtype) else param
+            new_types.append(new_type)
         return tuple(new_types)
     else:
         raise TypeError(
@@ -799,7 +867,11 @@ def extract_types(params: Any) -> Tuple:
             - DataFrame[index_type, [type, ...]]
             - DataFrame[(index_name, index_type), [(name, type), ...]]
             - DataFrame[dtype instance, dtypes instance]
-            - DataFrame[(index_name, index_type), zip(names, types)]\n"""
+            - DataFrame[(index_name, index_type), zip(names, types)]
+            - DataFrame[[index_type, ...], [type, ...]]
+            - DataFrame[[(index_name, index_type), ...], [(name, type), ...]]
+            - DataFrame[dtypes instance, dtypes instance]
+            - DataFrame[zip(index_names, index_types), zip(names, types)]\n"""
             + "However, got %s." % str(origin)
         )