[SPARK-35616][PYTHON] Make astype method data-type-based

### What changes were proposed in this pull request?

Make `astype` method data-type-based.

**Non-goal: Match pandas' `astype` TypeErrors.**
Currently, `astype` raises a TypeError only when the destination type is not recognized. However, for some destination types that make no sense for the specific kind of Series/Index, for example, `numeric Series/Index → bytes`, we don't raise a proper TypeError.
Since this PR is mainly a refactoring, the issue above can be addressed later if needed.
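For illustration (a hypothetical sketch; the exact error types and messages may vary):

```python
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3])

psser.astype("unknown-type")  # fails fast: the destination type itself is not understood
psser.astype(bytes)           # numeric -> bytes: currently no dedicated pandas-style TypeError
```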

### Why are the changes needed?

There are many type checks in the `astype` method. Now that `DataTypeOps` and its subclasses have been introduced, we should refactor `astype` to make it data-type-based, so each data type handles only the casts that apply to it. This way the code is cleaner, more maintainable, and more flexible.
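The shape of the refactoring is sketched below for reference. This is a simplified illustration, not the exact implementation; the real classes and the shared casting helpers (`_as_categorical_type`, `_as_bool_type`, `_as_string_type`, `_as_other_type`) are in the diff that follows.

```python
from abc import ABCMeta


class DataTypeOps(metaclass=ABCMeta):
    # Base class: by default, casting is not defined for a data type.
    pretty_name = "objects"

    def astype(self, index_ops, dtype):
        raise TypeError("astype can not be applied to %s." % self.pretty_name)


class BooleanOps(DataTypeOps):
    # Boolean columns: implements only the casts that are valid for booleans.
    pretty_name = "booleans"

    def astype(self, index_ops, dtype):
        # The real method resolves dtype via pandas_on_spark_type() and routes
        # to _as_categorical_type / _as_bool_type / _as_string_type /
        # _as_other_type, as shown in the diff below.
        ...


# IndexOpsMixin.astype (shared by Series and Index) then becomes a one-liner:
#     return self._dtype_op.astype(self, dtype)
```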

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.
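User-facing behavior is expected to be identical before and after; as a sanity check, something like the following (adapted from the doctests and the new `test_astype` cases in this diff; requires a running Spark session) exercises the new dispatch:

```python
import pyspark.pandas as ps
from pandas.api.types import CategoricalDtype

psser = ps.Series([1, 2], dtype="int32")

psser.astype("int64")                                 # now dispatched to IntegralOps.astype
psser.astype(str)                                     # numeric -> string cast
psser.astype(CategoricalDtype(categories=[2, 1, 3]))  # categorical destination
```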

Closes #32847 from xinrong-databricks/datatypeops_astype.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
Xinrong Meng, 2021-06-14 16:33:15 -07:00, committed by Takuya UESHIN
parent aab0c2bf66
commit 03756618fc
21 changed files with 533 additions and 131 deletions


@@ -29,14 +29,9 @@ from pandas.api.types import is_list_like, CategoricalDtype
 from pyspark import sql as spark
 from pyspark.sql import functions as F, Window, Column
 from pyspark.sql.types import (
-    BooleanType,
-    DateType,
     DoubleType,
     FloatType,
     LongType,
-    NumericType,
-    StringType,
-    TimestampType,
 )
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -52,7 +47,6 @@ from pyspark.pandas.spark.accessors import SparkIndexOpsMethods
 from pyspark.pandas.typedef import (
     Dtype,
     extension_dtypes,
-    pandas_on_spark_type,
 )
 from pyspark.pandas.utils import (
     combine_frames,
@@ -802,103 +796,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         >>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
         Int64Index([1, 2], dtype='int64', name='a')
         """
-        dtype, spark_type = pandas_on_spark_type(dtype)
-        if not spark_type:
-            raise ValueError("Type {} not understood".format(dtype))
-        if isinstance(self.dtype, CategoricalDtype):
-            if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
-                return cast(Union[ps.Index, ps.Series], self).copy()
-            categories = self.dtype.categories
-            if len(categories) == 0:
-                scol = F.lit(None)
-            else:
-                kvs = chain(
-                    *[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
-                )
-                map_scol = F.create_map(*kvs)
-                scol = map_scol.getItem(self.spark.column)
-            return self._with_new_scol(
-                scol.alias(self._internal.data_spark_column_names[0])
-            ).astype(dtype)
-        elif isinstance(dtype, CategoricalDtype):
-            if dtype.categories is None:
-                codes, uniques = self.factorize()
-                return codes._with_new_scol(
-                    codes.spark.column,
-                    field=codes._internal.data_fields[0].copy(
-                        dtype=CategoricalDtype(categories=uniques)
-                    ),
-                )
-            else:
-                categories = dtype.categories
-                if len(categories) == 0:
-                    scol = F.lit(-1)
-                else:
-                    kvs = chain(
-                        *[
-                            (F.lit(category), F.lit(code))
-                            for code, category in enumerate(categories)
-                        ]
-                    )
-                    map_scol = F.create_map(*kvs)
-                    scol = F.coalesce(map_scol.getItem(self.spark.column), F.lit(-1))
-                return self._with_new_scol(
-                    scol.cast(spark_type).alias(self._internal.data_fields[0].name),
-                    field=self._internal.data_fields[0].copy(
-                        dtype=dtype, spark_type=spark_type, nullable=False
-                    ),
-                )
-        if isinstance(spark_type, BooleanType):
-            if isinstance(dtype, extension_dtypes):
-                scol = self.spark.column.cast(spark_type)
-            else:
-                if isinstance(self.spark.data_type, StringType):
-                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                        F.length(self.spark.column) > 0
-                    )
-                elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
-                    scol = F.when(
-                        self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
-                    ).otherwise(self.spark.column.cast(spark_type))
-                else:
-                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                        self.spark.column.cast(spark_type)
-                    )
-        elif isinstance(spark_type, StringType):
-            if isinstance(dtype, extension_dtypes):
-                if isinstance(self.spark.data_type, BooleanType):
-                    scol = F.when(
-                        self.spark.column.isNotNull(),
-                        F.when(self.spark.column, "True").otherwise("False"),
-                    )
-                elif isinstance(self.spark.data_type, TimestampType):
-                    # seems like a pandas' bug?
-                    scol = F.when(self.spark.column.isNull(), str(pd.NaT)).otherwise(
-                        self.spark.column.cast(spark_type)
-                    )
-                else:
-                    scol = self.spark.column.cast(spark_type)
-            else:
-                if isinstance(self.spark.data_type, NumericType):
-                    null_str = str(np.nan)
-                elif isinstance(self.spark.data_type, (DateType, TimestampType)):
-                    null_str = str(pd.NaT)
-                else:
-                    null_str = str(None)
-                if isinstance(self.spark.data_type, BooleanType):
-                    casted = F.when(self.spark.column, "True").otherwise("False")
-                else:
-                    casted = self.spark.column.cast(spark_type)
-                scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
-        else:
-            scol = self.spark.column.cast(spark_type)
-        return self._with_new_scol(
-            scol.alias(self._internal.data_spark_column_names[0]), field=InternalField(dtype=dtype)
-        )
+        return self._dtype_op.astype(self, dtype)

     def isin(self, values) -> Union["Series", "Index"]:
         """


@@ -17,12 +17,14 @@
 import numbers
 from abc import ABCMeta
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, Union

 import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype

+from pyspark.sql import functions as F
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -39,9 +41,7 @@ from pyspark.sql.types import (
     TimestampType,
     UserDefinedType,
 )
-import pyspark.sql.types as types
-from pyspark.pandas.typedef import Dtype
+from pyspark.pandas.typedef import Dtype, extension_dtypes
 from pyspark.pandas.typedef.typehints import extension_object_dtypes_available

 if extension_object_dtypes_available:
@@ -70,7 +70,7 @@ def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool =

 def transform_boolean_operand_to_numeric(
-    operand: Any, spark_type: Optional[types.DataType] = None
+    operand: Any, spark_type: Optional[DataType] = None
 ) -> Any:
     """Transform boolean operand to numeric.
@@ -90,6 +90,99 @@ def transform_boolean_operand_to_numeric(
     return operand


+def _as_categorical_type(
+    index_ops: Union["Series", "Index"], dtype: CategoricalDtype, spark_type: DataType
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to categorical dtype, given `dtype` and `spark_type`."""
+    assert isinstance(dtype, CategoricalDtype)
+    if dtype.categories is None:
+        codes, uniques = index_ops.factorize()
+        return codes._with_new_scol(
+            codes.spark.column,
+            field=codes._internal.data_fields[0].copy(dtype=CategoricalDtype(categories=uniques)),
+        )
+    else:
+        categories = dtype.categories
+        if len(categories) == 0:
+            scol = F.lit(-1)
+        else:
+            kvs = chain(
+                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = F.coalesce(map_scol.getItem(index_ops.spark.column), F.lit(-1))
+        return index_ops._with_new_scol(
+            scol.cast(spark_type).alias(index_ops._internal.data_fields[0].name),
+            field=index_ops._internal.data_fields[0].copy(
+                dtype=dtype, spark_type=spark_type, nullable=False
+            ),
+        )
+
+
+def _as_bool_type(
+    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype]
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to BooleanType Spark type, given `dtype`."""
+    from pyspark.pandas.internal import InternalField
+
+    if isinstance(dtype, extension_dtypes):
+        scol = index_ops.spark.column.cast(BooleanType())
+    else:
+        scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+            index_ops.spark.column.cast(BooleanType())
+        )
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
+def _as_string_type(
+    index_ops: Union["Series", "Index"],
+    dtype: Union[str, type, Dtype],
+    *,
+    null_str: str = str(None)
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to StringType Spark type, given `dtype` and `null_str`,
+    representing null Spark column.
+    """
+    from pyspark.pandas.internal import InternalField
+
+    if isinstance(dtype, extension_dtypes):
+        scol = index_ops.spark.column.cast(StringType())
+    else:
+        casted = index_ops.spark.column.cast(StringType())
+        scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
+def _as_other_type(
+    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype], spark_type: DataType
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to a `dtype` (`spark_type`) that needs no pre-processing.
+
+    Destination types that need pre-processing: CategoricalDtype, BooleanType, and StringType.
+    """
+    from pyspark.pandas.internal import InternalField
+
+    need_pre_process = (
+        isinstance(dtype, CategoricalDtype)
+        or isinstance(spark_type, BooleanType)
+        or isinstance(spark_type, StringType)
+    )
+    assert not need_pre_process, "Pre-processing is needed before the type casting."
+
+    scol = index_ops.spark.column.cast(spark_type)
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
 class DataTypeOps(object, metaclass=ABCMeta):
     """The base class for binary operations of pandas-on-Spark objects (of different data types)."""
@@ -206,3 +299,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
     def prepare(self, col: pd.Series) -> pd.Series:
         """Prepare column when from_pandas."""
         return col.replace({np.nan: None})
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        raise TypeError("astype can not be applied to %s." % self.pretty_name)


@@ -17,10 +17,19 @@
 from typing import TYPE_CHECKING, Union

-from pyspark.sql import functions as F
-from pyspark.sql.types import BinaryType
+from pandas.api.types import CategoricalDtype
+
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql import functions as F
+from pyspark.sql.types import BinaryType, BooleanType, StringType

 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -53,3 +62,17 @@ class BinaryOps(DataTypeOps):
         raise TypeError(
             "Concatenation can not be applied to %s and the given type." % self.pretty_name
         )
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -19,6 +19,7 @@ import numbers
 from typing import TYPE_CHECKING, Union

 import pandas as pd
+from pandas.api.types import CategoricalDtype

 from pyspark import sql as spark
 from pyspark.pandas.base import column_op, IndexOpsMixin
@@ -26,11 +27,15 @@ from pyspark.pandas.data_type_ops.base import (
     is_valid_operand_for_numeric_arithmetic,
     DataTypeOps,
     transform_boolean_operand_to_numeric,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
 )
-from pyspark.pandas.typedef import extension_dtypes
+from pyspark.pandas.internal import InternalField
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
 from pyspark.pandas.typedef.typehints import as_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.types import BooleanType
+from pyspark.sql.types import BooleanType, StringType

 if TYPE_CHECKING:
@@ -245,6 +250,32 @@ class BooleanOps(DataTypeOps):
         return column_op(or_func)(left, right)

+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            if isinstance(dtype, extension_dtypes):
+                scol = F.when(
+                    index_ops.spark.column.isNotNull(),
+                    F.when(index_ops.spark.column, "True").otherwise("False"),
+                )
+            else:
+                null_str = str(None)
+                casted = F.when(index_ops.spark.column, "True").otherwise("False")
+                scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+

 class BooleanExtensionOps(BooleanOps):
     """


@@ -15,9 +15,20 @@
 # limitations under the License.
 #

-import pandas as pd
+from itertools import chain
+from typing import cast, TYPE_CHECKING, Union
+
+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
+import pyspark.pandas as ps
 from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql import functions as F
+
+if TYPE_CHECKING:
+    from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)

 class CategoricalOps(DataTypeOps):
@@ -38,3 +49,24 @@ class CategoricalOps(DataTypeOps):
     def prepare(self, col: pd.Series) -> pd.Series:
         """Prepare column when from_pandas."""
         return col.cat.codes
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
+            return cast(Union[ps.Index, ps.Series], index_ops).copy()
+
+        categories = index_ops.dtype.categories
+        if len(categories) == 0:
+            scol = F.lit(None)
+        else:
+            kvs = chain(
+                *[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = map_scol.getItem(index_ops.spark.column)
+        return index_ops._with_new_scol(
+            scol.alias(index_ops._internal.data_spark_column_names[0])
+        ).astype(dtype)


@@ -17,10 +17,19 @@
 from typing import TYPE_CHECKING, Union

+from pandas.api.types import CategoricalDtype
+
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.types import ArrayType, NumericType
+from pyspark.sql.types import ArrayType, BooleanType, NumericType, StringType

 if TYPE_CHECKING:
@@ -56,6 +65,20 @@ class ArrayOps(DataTypeOps):
         return column_op(F.concat)(left, right)

+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+

 class MapOps(DataTypeOps):
     """


@@ -19,11 +19,21 @@ import datetime
 import warnings
 from typing import TYPE_CHECKING, Union

+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
 from pyspark.sql import functions as F
-from pyspark.sql.types import DateType
+from pyspark.sql.types import BooleanType, DateType, StringType

 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type

 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -69,3 +79,17 @@ class DateOps(DataTypeOps):
             return -column_op(F.datediff)(left, F.lit(right)).astype("long")
         else:
             raise TypeError("date subtraction can only be applied to date series.")
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(pd.NaT))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -19,12 +19,21 @@ import datetime
 import warnings
 from typing import TYPE_CHECKING, Union

+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
 from pyspark.sql import functions as F
-from pyspark.sql.types import TimestampType
+from pyspark.sql.types import BooleanType, StringType, TimestampType

 from pyspark.pandas.base import IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
-from pyspark.pandas.typedef import as_spark_type
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+)
+from pyspark.pandas.internal import InternalField
+from pyspark.pandas.typedef import as_spark_type, Dtype, extension_dtypes, pandas_on_spark_type

 if TYPE_CHECKING:
@@ -78,3 +87,29 @@ class DatetimeOps(DataTypeOps):
     def prepare(self, col):
         """Prepare column when from_pandas."""
         return col
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            if isinstance(dtype, extension_dtypes):
+                # seems like a pandas' bug?
+                scol = F.when(index_ops.spark.column.isNull(), str(pd.NaT)).otherwise(
+                    index_ops.spark.column.cast(spark_type)
+                )
+            else:
+                null_str = str(pd.NaT)
+                casted = index_ops.spark.column.cast(spark_type)
+                scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -15,7 +15,23 @@
 # limitations under the License.
 #

-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from typing import TYPE_CHECKING, Union
+
+from pandas.api.types import CategoricalDtype
+
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql.types import BooleanType, StringType
+
+if TYPE_CHECKING:
+    from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)

 class NullOps(DataTypeOps):
@@ -26,3 +42,17 @@ class NullOps(DataTypeOps):
     @property
     def pretty_name(self) -> str:
         return "nulls"
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -19,22 +19,30 @@ import numbers
 from typing import TYPE_CHECKING, Union

 import numpy as np
+from pandas.api.types import CategoricalDtype

-from pyspark.sql import functions as F
-from pyspark.sql.types import (
-    StringType,
-    TimestampType,
-)
 from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op
 from pyspark.pandas.data_type_ops.base import (
     is_valid_operand_for_numeric_arithmetic,
     DataTypeOps,
     transform_boolean_operand_to_numeric,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
 )
+from pyspark.pandas.internal import InternalField
 from pyspark.pandas.spark import functions as SF
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
+from pyspark.sql import functions as F
 from pyspark.sql.column import Column
+from pyspark.sql.types import (
+    BooleanType,
+    DoubleType,
+    FloatType,
+    StringType,
+    TimestampType,
+)

 if TYPE_CHECKING:
@@ -248,6 +256,20 @@ class IntegralOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)

+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(np.nan))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+

 class FractionalOps(NumericOps):
     """
@@ -344,3 +366,32 @@ class FractionalOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)

+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            if isinstance(dtype, extension_dtypes):
+                scol = index_ops.spark.column.cast(spark_type)
+            else:
+                if isinstance(index_ops.spark.data_type, (FloatType, DoubleType)):
+                    scol = F.when(
+                        index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column),
+                        F.lit(True),
+                    ).otherwise(index_ops.spark.column.cast(spark_type))
+                else:  # DecimalType
+                    scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+                        index_ops.spark.column.cast(spark_type)
+                    )
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(np.nan))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -23,8 +23,16 @@ from pyspark.sql import functions as F
 from pyspark.sql.types import IntegralType, StringType

 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.internal import InternalField
 from pyspark.pandas.spark import functions as SF
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
+from pyspark.sql.types import BooleanType

 if TYPE_CHECKING:
@@ -102,3 +110,27 @@ class StringOps(DataTypeOps):
     def rmod(self, left, right):
         raise TypeError("modulo can not be applied on string series or literals.")
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        if isinstance(spark_type, BooleanType):
+            if isinstance(dtype, extension_dtypes):
+                scol = index_ops.spark.column.cast(spark_type)
+            else:
+                scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+                    F.length(index_ops.spark.column) > 0
+                )
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)


@@ -16,6 +16,7 @@
 #

 import pandas as pd
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -147,6 +148,14 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pd.Series(["1", "2", "3"]), psser.astype(str))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[b"2", b"3", b"1"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -21,6 +21,7 @@ from distutils.version import LooseVersion

 import pandas as pd
 import numpy as np
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -287,6 +288,21 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(True | pser, True | psser)
         self.assert_eq(False | pser, False | psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[False, True])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
 class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@@ -578,6 +594,14 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[False, True])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     from pyspark.pandas.tests.data_type_ops.test_boolean_ops import *  # noqa: F401


@@ -15,7 +15,11 @@
 # limitations under the License.
 #

+from distutils.version import LooseVersion
+
 import pandas as pd
+import numpy as np
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -140,6 +144,25 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        data = [1, 2, 3]
+        pser = pd.Series(data, dtype="category")
+        psser = ps.from_pandas(pser)
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[3, 1, 2])
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+        else:
+            self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -65,9 +65,13 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     def pssers(self):
         return self.numeric_array_pssers + list(self.non_numeric_array_pssers.values())

+    @property
+    def pser(self):
+        return pd.Series([[1, 2, 3]])
+
     @property
     def psser(self):
-        return ps.Series([[1, 2, 3]])
+        return ps.from_pandas(self.pser)

     def test_add(self):
         for pser, psser in zip(self.psers, self.pssers):
@@ -213,6 +217,9 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        self.assert_eq(self.pser.astype(str), self.psser.astype(str))
+

 if __name__ == "__main__":
     import unittest


@@ -18,6 +18,7 @@
 import datetime

 import pandas as pd
+from pandas.api.types import CategoricalDtype

 from pyspark.sql.types import DateType
@@ -172,6 +173,14 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pd.Series([None, None, None]), psser.astype(bool))
+        cat_type = CategoricalDtype(categories=["a", "b", "c"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -19,6 +19,7 @@ import datetime

 import numpy as np
 import pandas as pd
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -172,6 +173,14 @@ class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=["a", "b", "c"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -16,6 +16,7 @@
 #

 import pandas as pd
+from pandas.api.types import CategoricalDtype

 import pyspark.pandas as ps
 from pyspark.pandas.config import option_context
@@ -122,6 +123,15 @@ class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[1, 2, 3])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -20,6 +20,7 @@ from distutils.version import LooseVersion

 import pandas as pd
 import numpy as np
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -293,6 +294,20 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        for pser, psser in self.numeric_pser_psser_pairs:
+            self.assert_eq(pser.astype(int), psser.astype(int))
+            self.assert_eq(pser.astype(float), psser.astype(float))
+            self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+            self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+            self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+            self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+            self.assert_eq(pser.astype(str), psser.astype(str))
+            self.assert_eq(pser.astype(bool), psser.astype(bool))
+            self.assert_eq(pser.astype("category"), psser.astype("category"))
+            cat_type = CategoricalDtype(categories=[2, 1, 3])
+            self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest


@@ -17,6 +17,7 @@

 import numpy as np
 import pandas as pd
+from pandas.api.types import CategoricalDtype

 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -154,10 +155,25 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        pser = pd.Series(["1", "2", "3"])
+        psser = ps.from_pandas(pser)
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=["3", "1", "2"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+

 if __name__ == "__main__":
     import unittest
-    from pyspark.pandas.tests.data_type_ops.test_num_ops import *  # noqa: F401
+    from pyspark.pandas.tests.data_type_ops.test_string_ops import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]


@@ -125,6 +125,9 @@ class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)

+    def test_astype(self):
+        self.assertRaises(TypeError, lambda: self.psser.astype(str))
+

 if __name__ == "__main__":
     import unittest