[SPARK-36711][PYTHON] Support multi-index in new syntax
### What changes were proposed in this pull request?

Support a multi-index in the new type-hint syntax for specifying the index data type.

### Why are the changes needed?

The new syntax for specifying index data types previously covered only a single index; this change extends it to multi-indexes. See https://issues.apache.org/jira/browse/SPARK-36707.

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, users can write:

```python
>>> ps.DataFrame[[int, int], [int, int]]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
>>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
>>> pdf = pd.DataFrame([[1, 2, 3], [2, 3, 4], [4, 5, 6]], index=idx, columns=["a", "b", "c"])
>>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> ps.DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]

>>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]
typing.Tuple[pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.IndexNameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType, pyspark.pandas.typedef.typehints.NameType]
```

### How was this patch tested?

Existing tests, plus the new `apply_batch` tests added in this PR.

Closes #34176 from dchvn/SPARK-36711.

Authored-by: dchvn nguyen <dgd_contributor@viettel.com.vn>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Commit d6786e036d (parent fa1805db48).
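For orientation before the diff hunks below, the new syntax is exercised end to end through `DataFrame.pandas_on_spark.apply_batch` in the test added by this commit; the following is a condensed, hedged variant (the data values and names are illustrative only, not taken from the patch):

```python
import pandas as pd
import pyspark.pandas as ps

# Build a small frame with a two-level ("number", "color") index.
arrays = [[1, 1, 2], ["red", "blue", "red"]]
idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
psdf = ps.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=idx))

# Index types come first, data column types second; names can be attached as (name, type) pairs.
def identity(x) -> ps.DataFrame[[("number", int), ("color", str)], [("a", int), ("b", int)]]:
    return x

# The hinted two-level index is retained on the result.
psdf.pandas_on_spark.apply_batch(identity)
```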
```diff
@@ -34,7 +34,7 @@ from pyspark.pandas.internal import (
     InternalFrame,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.typedef import infer_return_type, DataFrameType, ScalarType, SeriesType
 from pyspark.pandas.utils import (
@@ -384,8 +384,8 @@ class PandasOnSparkFrameMethods(object):
                     "The given function should specify a frame as its type "
                     "hints; however, the return type was %s." % return_sig
                 )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             return_schema = cast(DataFrameType, return_type).spark_type

             output_func = GroupBy._make_pandas_df_builder_func(
@@ -397,12 +397,19 @@ class PandasOnSparkFrameMethods(object):
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
             index_fields = None

             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [(index_field.struct_field.name,) for index_field in index_fields]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
```
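The retention rule in the hunk above keys on the index field names coming from the type hint: user-supplied names are kept only when none of the fields carries an auto-generated default name. A simplified, standalone sketch of that rule (the regex is an assumption that mirrors `SPARK_INDEX_NAME_PATTERN` in `pyspark.pandas.internal`):

```python
import re
from typing import List, Optional, Tuple

# Assumed to match the default names produced by SPARK_INDEX_NAME_FORMAT ("__index_level_0__", ...).
SPARK_INDEX_NAME_PATTERN = re.compile(r"__index_level_[0-9]+__")


def hinted_index_names(field_names: List[str]) -> Optional[List[Tuple[str, ...]]]:
    # Keep the hinted names only if no field looks auto-generated; otherwise leave names as None.
    if not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in field_names):
        return [(name,) for name in field_names]
    return None


print(hinted_index_names(["number", "color"]))             # [('number',), ('color',)]
print(hinted_index_names(["__index_level_0__", "color"]))  # None
```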
```diff
@@ -680,17 +687,19 @@ class PandasOnSparkFrameMethods(object):
                 )
                 return first_series(DataFrame(internal))
             else:
-                index_field = cast(DataFrameType, return_type).index_field
-                index_field = (
-                    index_field.normalize_spark_type() if index_field is not None else None
+                index_fields = cast(DataFrameType, return_type).index_fields
+                index_fields = (
+                    [index_field.normalize_spark_type() for index_field in index_fields]
+                    if index_fields is not None
+                    else None
                 )
                 data_fields = [
                     field.normalize_spark_type()
                     for field in cast(DataFrameType, return_type).data_fields
                 ]
-                normalized_fields = ([index_field] if index_field is not None else []) + data_fields
+                normalized_fields = (index_fields if index_fields is not None else []) + data_fields
                 return_schema = StructType([field.struct_field for field in normalized_fields])
-                should_retain_index = index_field is not None
+                should_retain_index = index_fields is not None

                 self_applied = DataFrame(self._psdf._internal.resolved_copy)
@@ -711,12 +720,21 @@ class PandasOnSparkFrameMethods(object):
             index_spark_columns = None
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
             index_fields = None

             if should_retain_index:
-                index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-                index_fields = [index_field]
-                if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                    index_names = [(index_field.struct_field.name,)]
+                index_spark_columns = [
+                    scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+                ]
+
+                if not any(
+                    [
+                        SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                        for index_field in index_fields
+                    ]
+                ):
+                    index_names = [
+                        (index_field.struct_field.name,) for index_field in index_fields
+                    ]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
```
```diff
@@ -114,6 +114,7 @@ from pyspark.pandas.internal import (
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_INDEX_NAME,
     SPARK_DEFAULT_SERIES_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.frame import _MissingPandasLikeDataFrame
 from pyspark.pandas.ml import corr
@@ -2511,7 +2512,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         return_type = infer_return_type(func)
         require_index_axis = isinstance(return_type, SeriesType)
         require_column_axis = isinstance(return_type, DataFrameType)
-        index_field = None
+        index_fields = None

         if require_index_axis:
             if axis != 0:
@@ -2536,8 +2537,8 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                    "hints when axis is 1 or 'column'; however, the return type "
                    "was %s" % return_sig
                )
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             data_fields = cast(DataFrameType, return_type).data_fields
             return_schema = cast(DataFrameType, return_type).spark_type
         else:
@@ -2565,12 +2566,19 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         index_spark_columns = None
         index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
         index_fields = None

         if should_retain_index:
-            index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-            index_fields = [index_field]
-            if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                index_names = [(index_field.struct_field.name,)]
+            index_spark_columns = [
+                scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+            ]
+
+            if not any(
+                [
+                    SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                    for index_field in index_fields
+                ]
+            ):
+                index_names = [(index_field.struct_field.name,) for index_field in index_fields]
         internal = InternalFrame(
             spark_frame=sdf,
             index_names=index_names,
```
```diff
@@ -76,7 +76,7 @@ from pyspark.pandas.internal import (
     NATURAL_ORDER_COLUMN_NAME,
     SPARK_INDEX_NAME_FORMAT,
     SPARK_DEFAULT_SERIES_NAME,
-    SPARK_DEFAULT_INDEX_NAME,
+    SPARK_INDEX_NAME_PATTERN,
 )
 from pyspark.pandas.missing.groupby import (
     MissingPandasLikeDataFrameGroupBy,
@@ -1252,9 +1252,8 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         if isinstance(return_type, DataFrameType):
             data_fields = cast(DataFrameType, return_type).data_fields
             return_schema = cast(DataFrameType, return_type).spark_type
-            index_field = cast(DataFrameType, return_type).index_field
-            should_retain_index = index_field is not None
-            index_fields = [index_field]
+            index_fields = cast(DataFrameType, return_type).index_fields
+            should_retain_index = index_fields is not None
             psdf_from_pandas = None
         else:
             should_return_series = True
@@ -1329,10 +1328,18 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
             )
         else:
             index_names: Optional[List[Optional[Tuple[Any, ...]]]] = None
-            index_field = index_fields[0]
-            index_spark_columns = [scol_for(sdf, index_field.struct_field.name)]
-            if index_field.struct_field.name != SPARK_DEFAULT_INDEX_NAME:
-                index_names = [(index_field.struct_field.name,)]
+
+            index_spark_columns = [
+                scol_for(sdf, index_field.struct_field.name) for index_field in index_fields
+            ]
+
+            if not any(
+                [
+                    SPARK_INDEX_NAME_PATTERN.match(index_field.struct_field.name)
+                    for index_field in index_fields
+                ]
+            ):
+                index_names = [(index_field.struct_field.name,) for index_field in index_fields]
             internal = InternalFrame(
                 spark_frame=sdf,
                 index_names=index_names,
```
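The GroupBy hunks mirror the accessor changes: `GroupBy.apply` now reads a list of `index_fields` from the `DataFrameType` hint instead of a single field. A hedged, illustrative sketch of how that path might be exercised (the frame, names, and grouping key are assumptions for illustration, not taken from the patch):

```python
import pandas as pd
import pyspark.pandas as ps

psdf = ps.DataFrame({"number": [1, 1, 2], "color": ["red", "blue", "red"], "a": [1, 2, 3]})

# The hint declares a two-level index plus one data column for the returned frame.
def reindex(pdf: pd.DataFrame) -> ps.DataFrame[
    [("number", int), ("color", str)], [("a", int)]
]:
    return pdf.set_index(["number", "color"])[["a"]]

# With the plural index_fields, the hinted ("number", "color") index can be retained.
psdf.groupby("number").apply(reindex)
```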
```diff
@@ -4678,6 +4678,32 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
         actual.columns = ["a", "b"]
         self.assert_eq(actual, pdf)

+        arrays = [[1, 2, 3, 4, 5, 6, 7, 8, 9], ["a", "b", "c", "d", "e", "f", "g", "h", "i"]]
+        idx = pd.MultiIndex.from_arrays(arrays, names=("number", "color"))
+        pdf = pd.DataFrame(
+            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [[e] for e in [4, 5, 6, 3, 2, 1, 0, 0, 0]]},
+            index=idx,
+        )
+        psdf = ps.from_pandas(pdf)
+
+        def identify4(x) -> ps.DataFrame[[int, str], [int, List[int]]]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify4)
+        actual.index.names = ["number", "color"]
+        actual.columns = ["a", "b"]
+        self.assert_eq(actual, pdf)
+
+        def identify5(
+            x,
+        ) -> ps.DataFrame[
+            [("number", int), ("color", str)], [("a", int), ("b", List[int])]  # noqa: F405
+        ]:
+            return x
+
+        actual = psdf.pandas_on_spark.apply_batch(identify5)
+        self.assert_eq(actual, pdf)
+
     def test_transform_batch(self):
         pdf = pd.DataFrame(
             {
```
```diff
@@ -94,12 +94,12 @@ class SeriesType(Generic[T]):
 class DataFrameType(object):
     def __init__(
         self,
-        index_field: Optional["InternalField"],
+        index_fields: Optional[List["InternalField"]],
         data_fields: List["InternalField"],
     ):
-        self.index_field = index_field
+        self.index_fields = index_fields
         self.data_fields = data_fields
-        self.fields = [index_field] + data_fields if index_field is not None else data_fields
+        self.fields = index_fields + data_fields if isinstance(index_fields, List) else data_fields

     @property
     def dtypes(self) -> List[Dtype]:
```
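With the constructor above taking `index_fields` as a list, the Spark schema is simply the concatenation of index and data fields. A hedged sketch of constructing one by hand, the same way `infer_return_type` does further below (the field names and types are illustrative assumptions):

```python
import numpy as np
from pyspark.sql import types
from pyspark.pandas.internal import InternalField
from pyspark.pandas.typedef.typehints import DataFrameType

index_fields = [
    InternalField(dtype=np.dtype("int64"),
                  struct_field=types.StructField("number", types.LongType())),
    InternalField(dtype=np.dtype("int64"),
                  struct_field=types.StructField("color", types.LongType())),
]
data_fields = [
    InternalField(dtype=np.dtype("int64"),
                  struct_field=types.StructField("a", types.LongType())),
]

dft = DataFrameType(index_fields=index_fields, data_fields=data_fields)
print(dft.spark_type.simpleString())  # struct<number:bigint,color:bigint,a:bigint>
```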
```diff
@@ -514,8 +514,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]

     >>> def func() -> ps.DataFrame[pdf.index.dtype, pdf.dtypes]:
     ...     pass
@@ -524,8 +524,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,c0:bigint,c1:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]

     >>> def func() -> ps.DataFrame[
     ...     ("index", CategoricalDtype(categories=[3, 4, 5], ordered=False)),
@@ -536,8 +536,8 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [CategoricalDtype(categories=[3, 4, 5], ordered=False), dtype('int64'), dtype('int64')]
     >>> inferred.spark_type.simpleString()
     'struct<index:bigint,id:bigint,A:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=category,struct_field=StructField(index,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=category,struct_field=StructField(index,LongType,true))]

     >>> def func() -> ps.DataFrame[
     ...     (pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]:
@@ -547,13 +547,13 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
     [dtype('int64'), dtype('int64'), CategoricalDtype(categories=[3, 4, 5], ordered=False)]
     >>> inferred.spark_type.simpleString()
     'struct<__index_level_0__:bigint,a:bigint,b:bigint>'
-    >>> inferred.index_field
-    InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))
+    >>> inferred.index_fields
+    [InternalField(dtype=int64,struct_field=StructField(__index_level_0__,LongType,true))]
     """
     # We should re-import to make sure the class 'SeriesType' is not treated as a class
     # within this module locally. See Series.__class_getitem__ which imports this class
     # canonically.
-    from pyspark.pandas.internal import InternalField, SPARK_DEFAULT_INDEX_NAME
+    from pyspark.pandas.internal import InternalField, SPARK_INDEX_NAME_FORMAT
     from pyspark.pandas.typedef import SeriesType, NameTypeHolder, IndexNameTypeHolder
     from pyspark.pandas.utils import name_like_string
@@ -595,20 +595,26 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
         data_parameters = [p for p in parameters if p not in index_parameters]
         assert len(data_parameters) > 0, "Type hints for data must not be empty."

-        if len(index_parameters) == 1:
-            index_name = index_parameters[0].name
-            index_dtype, index_spark_type = pandas_on_spark_type(index_parameters[0].tpe)
-            index_field = InternalField(
-                dtype=index_dtype,
-                struct_field=types.StructField(
-                    name=index_name if index_name is not None else SPARK_DEFAULT_INDEX_NAME,
-                    dataType=index_spark_type,
-                ),
-            )
+        index_fields = []
+        if len(index_parameters) >= 1:
+            for level, index_parameter in enumerate(index_parameters):
+                index_name = index_parameter.name
+                index_dtype, index_spark_type = pandas_on_spark_type(index_parameter.tpe)
+                index_fields.append(
+                    InternalField(
+                        dtype=index_dtype,
+                        struct_field=types.StructField(
+                            name=index_name
+                            if index_name is not None
+                            else SPARK_INDEX_NAME_FORMAT(level),
+                            dataType=index_spark_type,
+                        ),
+                    )
+                )
         else:
             assert len(index_parameters) == 0
             # No type hint for index.
-            index_field = None
+            index_fields = None

         data_dtypes, data_spark_types = zip(
             *(
```
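To complement the single-index doctests above, a hedged example of what the new loop produces for an unnamed multi-index hint; the default field names follow `SPARK_INDEX_NAME_FORMAT`, so `__index_level_0__`/`__index_level_1__` below are an assumption consistent with that loop rather than output copied from the patch:

```python
>>> import pyspark.pandas as ps
>>> from pyspark.pandas.typedef import infer_return_type

>>> def func() -> ps.DataFrame[[int, int], [int, int]]:
...     pass
>>> inferred = infer_return_type(func)
>>> len(inferred.index_fields)
2
>>> inferred.spark_type.simpleString()
'struct<__index_level_0__:bigint,__index_level_1__:bigint,c0:bigint,c1:bigint>'
```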
```diff
@@ -636,7 +642,7 @@ def infer_return_type(f: Callable) -> Union[SeriesType, DataFrameType, ScalarTyp
             )
         )

-        return DataFrameType(index_field=index_field, data_fields=data_fields)
+        return DataFrameType(index_fields=index_fields, data_fields=data_fields)

     tpes = pandas_on_spark_type(tpe)
     if tpes is None:
@@ -684,10 +690,10 @@ def create_tuple_for_frame_type(params: Any) -> object:

     Typing data columns only:

-    >>> ps.DataFrame[float, float]
-    typing.Tuple[float, float]
-    >>> ps.DataFrame[pdf.dtypes]
-    typing.Tuple[numpy.int64]
+    >>> ps.DataFrame[float, float]  # doctest: +ELLIPSIS
+    typing.Tuple[...NameType, ...NameType]
+    >>> ps.DataFrame[pdf.dtypes]  # doctest: +ELLIPSIS
+    typing.Tuple[...NameType]
     >>> ps.DataFrame["id": int, "A": int]  # doctest: +ELLIPSIS
     typing.Tuple[...NameType, ...NameType]
     >>> ps.DataFrame[zip(pdf.columns, pdf.dtypes)]  # doctest: +ELLIPSIS
@@ -696,48 +702,42 @@ def create_tuple_for_frame_type(params: Any) -> object:
     Typing data columns with an index:

     >>> ps.DataFrame[int, [int, int]]  # doctest: +ELLIPSIS
-    typing.Tuple[...IndexNameType, int, int]
+    typing.Tuple[...IndexNameType, ...NameType, ...NameType]
     >>> ps.DataFrame[pdf.index.dtype, pdf.dtypes]  # doctest: +ELLIPSIS
-    typing.Tuple[...IndexNameType, numpy.int64]
+    typing.Tuple[...IndexNameType, ...NameType]
     >>> ps.DataFrame[("index", int), [("id", int), ("A", int)]]  # doctest: +ELLIPSIS
     typing.Tuple[...IndexNameType, ...NameType, ...NameType]
     >>> ps.DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]
     ...  # doctest: +ELLIPSIS
     typing.Tuple[...IndexNameType, ...NameType]

+    Typing data columns with an Multi-index:
+    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+    >>> pdf = pd.DataFrame({'a': range(3)}, index=idx)
+    >>> ps.DataFrame[[int, int], [int, int]]  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType]
+    >>> ps.DataFrame[pdf.index.dtypes, pdf.dtypes]  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...NameType]
+    >>> ps.DataFrame[[("index-1", int), ("index-2", int)], [("id", int), ("A", int)]]
+    ...  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...IndexNameType, ...NameType, ...NameType]
+    >>> ps.DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]
+    ...  # doctest: +ELLIPSIS
+    typing.Tuple[...IndexNameType, ...NameType]
     """
-    return Tuple[extract_types(params)]
+    return Tuple[_extract_types(params)]


 # TODO(SPARK-36708): numpy.typing (numpy 1.21+) support for nested types.
-def extract_types(params: Any) -> Tuple:
+def _extract_types(params: Any) -> Tuple:
     origin = params
-    if isinstance(params, zip):
-        # Example:
-        # DataFrame[zip(pdf.columns, pdf.dtypes)]
-        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
-
-    if isinstance(params, Iterable):
-        params = tuple(params)
-    else:
-        params = (params,)
+    params = _to_tuple_of_params(params)

-    if all(
-        isinstance(param, slice)
-        and param.start is not None
-        and param.step is None
-        and param.stop is not None
-        for param in params
-    ):
+    if _is_named_params(params):
         # Example:
         # DataFrame["id": int, "A": int]
-        new_params = []
-        for param in params:
-            new_param = type("NameType", (NameTypeHolder,), {})  # type: Type[NameTypeHolder]
-            new_param.name = param.start
-            # When the given argument is a numpy's dtype instance.
-            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
-            new_params.append(new_param)
-
+        new_params = _address_named_type_hoders(params, is_index=False)
         return tuple(new_params)
     elif len(params) == 2 and isinstance(params[1], (zip, list, pd.Series)):
         # Example:
@@ -745,49 +745,117 @@ def extract_types(params: Any) -> Tuple:
         # DataFrame[pdf.index.dtype, pdf.dtypes]
         # DataFrame[("index", int), [("id", int), ("A", int)]]
         # DataFrame[(pdf.index.name, pdf.index.dtype), zip(pdf.columns, pdf.dtypes)]
+        #
+        # DataFrame[[int, int], [int, int]]
+        # DataFrame[pdf.index.dtypes, pdf.dtypes]
+        # DataFrame[[("index", int), ("index-2", int)], [("id", int), ("A", int)]]
+        # DataFrame[zip(pdf.index.names, pdf.index.dtypes), zip(pdf.columns, pdf.dtypes)]

-        index_param = params[0]
-        index_type = type(
-            "IndexNameType", (IndexNameTypeHolder,), {}
-        )  # type: Type[IndexNameTypeHolder]
-        if isinstance(index_param, tuple):
-            if len(index_param) != 2:
-                raise TypeError(
-                    "Type hints for index should be specified as "
-                    "DataFrame[('name', type), ...]; however, got %s" % index_param
-                )
-            name, tpe = index_param
-        else:
-            name, tpe = None, index_param
+        index_params = params[0]

-        index_type.name = name
-        if isinstance(tpe, ExtensionDtype):
-            index_type.tpe = tpe
+        if isinstance(index_params, tuple) and len(index_params) == 2:
+            index_params = tuple([slice(*index_params)])
+
+        index_params = _convert_tuples_to_zip(index_params)
+        index_params = _to_tuple_of_params(index_params)
+
+        if _is_named_params(index_params):
+            # Example:
+            # DataFrame[[("id", int), ("A", int)], [int, int]]
+            new_index_params = _address_named_type_hoders(index_params, is_index=True)
+            index_types = tuple(new_index_params)
         else:
-            index_type.tpe = tpe.type if isinstance(tpe, np.dtype) else tpe
+            # Exaxmples:
+            # DataFrame[[float, float], [int, int]]
+            # DataFrame[pdf.dtypes, [int, int]]
+            index_types = _address_unnamed_type_holders(index_params, origin, is_index=True)

         data_types = params[1]
-        if (
-            isinstance(data_types, list)
-            and len(data_types) >= 1
-            and isinstance(data_types[0], tuple)
-        ):
-            # Example:
-            # DataFrame[("index", int), [("id", int), ("A", int)]]
-            data_types = zip((name for name, _ in data_types), (tpe for _, tpe in data_types))
-        return (index_type,) + extract_types(data_types)
-    elif all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
+        data_types = _convert_tuples_to_zip(data_types)
+
+        return index_types + _extract_types(data_types)
+
     else:
         # Exaxmples:
         # DataFrame[float, float]
         # DataFrame[pdf.dtypes]
+        return _address_unnamed_type_holders(params, origin, is_index=False)
+
+
+def _is_named_params(params: Any) -> Any:
+    return all(
+        isinstance(param, slice) and param.step is None and param.stop is not None
+        for param in params
+    )
+
+
+def _address_named_type_hoders(params: Any, is_index: bool) -> Any:
+    # Example:
+    # params = (slice("id", int, None), slice("A", int, None))
+    new_params = []
+    for param in params:
+        new_param = (
+            type("IndexNameType", (IndexNameTypeHolder,), {})
+            if is_index
+            else type("NameType", (NameTypeHolder,), {})
+        )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
+        new_param.name = param.start
+        if isinstance(param.stop, ExtensionDtype):
+            new_param.tpe = param.stop
+        else:
+            # When the given argument is a numpy's dtype instance.
+            new_param.tpe = param.stop.type if isinstance(param.stop, np.dtype) else param.stop
+        new_params.append(new_param)
+    return new_params
+
+
+def _to_tuple_of_params(params: Any) -> Any:
+    """
+    >>> _to_tuple_of_params(int)
+    (<class 'int'>,)
+
+    >>> _to_tuple_of_params([int, int, int])
+    (<class 'int'>, <class 'int'>, <class 'int'>)
+
+    >>> arrays = [[1, 1, 2], ['red', 'blue', 'red']]
+    >>> idx = pd.MultiIndex.from_arrays(arrays, names=('number', 'color'))
+    >>> pdf = pd.DataFrame([[1, 2], [2, 3], [4, 5]], index=idx, columns=["a", "b"])
+
+    >>> _to_tuple_of_params(zip(pdf.columns, pdf.dtypes))
+    (slice('a', dtype('int64'), None), slice('b', dtype('int64'), None))
+    >>> _to_tuple_of_params(zip(pdf.index.names, pdf.index.dtypes))
+    (slice('number', dtype('int64'), None), slice('color', dtype('O'), None))
+    """
+    if isinstance(params, zip):
+        params = tuple(slice(name, tpe) for name, tpe in params)  # type: ignore[misc, has-type]
+
+    if isinstance(params, Iterable):
+        params = tuple(params)
+    else:
+        params = (params,)
+    return params
+
+
+def _convert_tuples_to_zip(params: Any) -> Any:
+    if isinstance(params, list) and len(params) >= 1 and isinstance(params[0], tuple):
+        return zip((name for name, _ in params), (tpe for _, tpe in params))
+    return params
+
+
+def _address_unnamed_type_holders(params: Any, origin: Any, is_index: bool) -> Any:
+    if all(not isinstance(param, slice) and not isinstance(param, Iterable) for param in params):
         new_types = []
         for param in params:
+            new_type = (
+                type("IndexNameType", (IndexNameTypeHolder,), {})
+                if is_index
+                else type("NameType", (NameTypeHolder,), {})
+            )  # type: Union[Type[IndexNameTypeHolder], Type[NameTypeHolder]]
             if isinstance(param, ExtensionDtype):
-                new_type = type("NameType", (NameTypeHolder,), {})  # type: Type[NameTypeHolder]
                 new_type.tpe = param
                 new_types.append(new_type)
             else:
-                new_types.append(param.type if isinstance(param, np.dtype) else param)
+                new_type.tpe = param.type if isinstance(param, np.dtype) else param
+                new_types.append(new_type)
         return tuple(new_types)
     else:
         raise TypeError(
```
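A hedged walk-through (not part of the patch) of how the new helpers cooperate on the index half of a hint such as `DataFrame[[("number", int), ("color", str)], ...]`; the printed values are what the code above implies rather than recorded output:

```python
from pyspark.pandas.typedef.typehints import (
    _address_named_type_hoders,
    _convert_tuples_to_zip,
    _is_named_params,
    _to_tuple_of_params,
)

# A list of (name, type) pairs is first turned into a zip, then into name/type slices.
index_params = _to_tuple_of_params(_convert_tuples_to_zip([("number", int), ("color", str)]))
# index_params == (slice('number', <class 'int'>, None), slice('color', <class 'str'>, None))

assert _is_named_params(index_params)

# Named parameters become fresh IndexNameTypeHolder subclasses carrying .name and .tpe.
index_types = tuple(_address_named_type_hoders(index_params, is_index=True))
print([(t.name, t.tpe) for t in index_types])  # [('number', <class 'int'>), ('color', <class 'str'>)]
```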
```diff
@@ -799,7 +867,11 @@ def extract_types(params: Any) -> Tuple:
   - DataFrame[index_type, [type, ...]]
   - DataFrame[(index_name, index_type), [(name, type), ...]]
   - DataFrame[dtype instance, dtypes instance]
-  - DataFrame[(index_name, index_type), zip(names, types)]\n"""
+  - DataFrame[(index_name, index_type), zip(names, types)]
+  - DataFrame[[index_type, ...], [type, ...]]
+  - DataFrame[[(index_name, index_type), ...], [(name, type), ...]]
+  - DataFrame[dtypes instance, dtypes instance]
+  - DataFrame[zip(index_names, index_types), zip(names, types)]\n"""
             + "However, got %s." % str(origin)
         )
```
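As a hedged illustration (only the tail of the message is visible in this hunk, so the full wording is not quoted here), a hint that matches none of the listed forms falls through to this `TypeError`:

```python
import pyspark.pandas as ps

try:
    ps.DataFrame["a"]  # a bare string is neither a type, a (name, type) pair, nor a dtype
except TypeError as exc:
    print(str(exc).endswith("However, got a."))  # True
```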