[SPARK-35616][PYTHON] Make astype method data-type-based

### What changes were proposed in this pull request?

Make `astype` method data-type-based.

**Non-goal: Match pandas' `astype` TypeErrors.**
Currently, `astype` raises a TypeError only when the destination type is not recognized. However, for destination types that don't make sense for the specific type of Series/Index, for example `numeric Series/Index → bytes`, we don't raise a proper TypeError.
Since this PR is mainly a refactoring, the above issue can be resolved later if needed.
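For illustration, a sketch of the gap (illustrative only; `bytes` maps to Spark's `BinaryType` in the type mapping):

```python
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3])

# A numeric -> bytes cast doesn't make sense, but as described above,
# there is currently no dtype-specific TypeError for cases like this.
psser.astype(bytes)
```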

### Why are the changes needed?

There are many type checks in the `astype` method. Now that `DataTypeOps` and its subclasses have been introduced, we should refactor `astype` to make it data-type-based. This makes the code cleaner, more maintainable, and more flexible.
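As a simplified sketch of the resulting shape (condensed from the diff below; bodies elided), the per-dtype casting rules move out of the single `IndexOpsMixin.astype` method and into the `DataTypeOps` subclasses:

```python
# Simplified sketch of the data-type-based dispatch, condensed from the
# diff below; `pretty_name` is stubbed and method bodies are elided.
from abc import ABCMeta


class DataTypeOps(object, metaclass=ABCMeta):
    @property
    def pretty_name(self) -> str:
        return "objects"

    def astype(self, index_ops, dtype):
        # The base class rejects the cast; each subclass overrides this
        # with the casting rules for its own source dtype.
        raise TypeError("astype can not be applied to %s." % self.pretty_name)


class StringOps(DataTypeOps):
    def astype(self, index_ops, dtype):
        ...  # string-specific rules, e.g. casting str -> bool tests string length


# IndexOpsMixin.astype then reduces to a single dispatch line:
#     return self._dtype_op.astype(self, dtype)
```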

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.
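For example, a representative check paraphrased from the new `test_astype` cases below (the tests use `assert_eq` to compare pandas-on-Spark results against plain pandas):

```python
# Paraphrased from NumOpsTest.test_astype below; assumes a running
# Spark session backing pandas-on-Spark.
import pandas as pd
import pyspark.pandas as ps
from pandas.api.types import CategoricalDtype

pser = pd.Series([1, 2, 3])
psser = ps.from_pandas(pser)

# Each cast now routes through the matching DataTypeOps subclass and
# should agree with plain pandas.
assert psser.astype(str).to_pandas().equals(pser.astype(str))
assert psser.astype(bool).to_pandas().equals(pser.astype(bool))

cat_type = CategoricalDtype(categories=[2, 1, 3])
assert psser.astype(cat_type).to_pandas().equals(pser.astype(cat_type))
```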

Closes #32847 from xinrong-databricks/datatypeops_astype.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
commit 03756618fc (parent aab0c2bf66)
Xinrong Meng, 2021-06-14 16:33:15 -07:00; committed by Takuya UESHIN
21 changed files with 533 additions and 131 deletions

View file: python/pyspark/pandas/base.py

@@ -29,14 +29,9 @@ from pandas.api.types import is_list_like, CategoricalDtype
from pyspark import sql as spark
from pyspark.sql import functions as F, Window, Column
from pyspark.sql.types import (
BooleanType,
DateType,
DoubleType,
FloatType,
LongType,
NumericType,
StringType,
TimestampType,
)
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
@@ -52,7 +47,6 @@ from pyspark.pandas.spark.accessors import SparkIndexOpsMethods
from pyspark.pandas.typedef import (
Dtype,
extension_dtypes,
pandas_on_spark_type,
)
from pyspark.pandas.utils import (
combine_frames,
@@ -802,103 +796,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
>>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
Int64Index([1, 2], dtype='int64', name='a')
"""
dtype, spark_type = pandas_on_spark_type(dtype)
if not spark_type:
raise ValueError("Type {} not understood".format(dtype))
if isinstance(self.dtype, CategoricalDtype):
if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
return cast(Union[ps.Index, ps.Series], self).copy()
categories = self.dtype.categories
if len(categories) == 0:
scol = F.lit(None)
else:
kvs = chain(
*[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
)
map_scol = F.create_map(*kvs)
scol = map_scol.getItem(self.spark.column)
return self._with_new_scol(
scol.alias(self._internal.data_spark_column_names[0])
).astype(dtype)
elif isinstance(dtype, CategoricalDtype):
if dtype.categories is None:
codes, uniques = self.factorize()
return codes._with_new_scol(
codes.spark.column,
field=codes._internal.data_fields[0].copy(
dtype=CategoricalDtype(categories=uniques)
),
)
else:
categories = dtype.categories
if len(categories) == 0:
scol = F.lit(-1)
else:
kvs = chain(
*[
(F.lit(category), F.lit(code))
for code, category in enumerate(categories)
]
)
map_scol = F.create_map(*kvs)
scol = F.coalesce(map_scol.getItem(self.spark.column), F.lit(-1))
return self._with_new_scol(
scol.cast(spark_type).alias(self._internal.data_fields[0].name),
field=self._internal.data_fields[0].copy(
dtype=dtype, spark_type=spark_type, nullable=False
),
)
if isinstance(spark_type, BooleanType):
if isinstance(dtype, extension_dtypes):
scol = self.spark.column.cast(spark_type)
else:
if isinstance(self.spark.data_type, StringType):
scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
F.length(self.spark.column) > 0
)
elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
scol = F.when(
self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
).otherwise(self.spark.column.cast(spark_type))
else:
scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
self.spark.column.cast(spark_type)
)
elif isinstance(spark_type, StringType):
if isinstance(dtype, extension_dtypes):
if isinstance(self.spark.data_type, BooleanType):
scol = F.when(
self.spark.column.isNotNull(),
F.when(self.spark.column, "True").otherwise("False"),
)
elif isinstance(self.spark.data_type, TimestampType):
# seems like a pandas' bug?
scol = F.when(self.spark.column.isNull(), str(pd.NaT)).otherwise(
self.spark.column.cast(spark_type)
)
else:
scol = self.spark.column.cast(spark_type)
else:
if isinstance(self.spark.data_type, NumericType):
null_str = str(np.nan)
elif isinstance(self.spark.data_type, (DateType, TimestampType)):
null_str = str(pd.NaT)
else:
null_str = str(None)
if isinstance(self.spark.data_type, BooleanType):
casted = F.when(self.spark.column, "True").otherwise("False")
else:
casted = self.spark.column.cast(spark_type)
scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
else:
scol = self.spark.column.cast(spark_type)
return self._with_new_scol(
scol.alias(self._internal.data_spark_column_names[0]), field=InternalField(dtype=dtype)
)
return self._dtype_op.astype(self, dtype)
def isin(self, values) -> Union["Series", "Index"]:
"""

View file: python/pyspark/pandas/data_type_ops/base.py

@@ -17,12 +17,14 @@
import numbers
from abc import ABCMeta
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, Union
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark.sql import functions as F
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -39,9 +41,7 @@ from pyspark.sql.types import (
TimestampType,
UserDefinedType,
)
import pyspark.sql.types as types
from pyspark.pandas.typedef import Dtype
from pyspark.pandas.typedef import Dtype, extension_dtypes
from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
if extension_object_dtypes_available:
@@ -70,7 +70,7 @@ def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool =
def transform_boolean_operand_to_numeric(
operand: Any, spark_type: Optional[types.DataType] = None
operand: Any, spark_type: Optional[DataType] = None
) -> Any:
"""Transform boolean operand to numeric.
@@ -90,6 +90,99 @@ def transform_boolean_operand_to_numeric(
return operand
def _as_categorical_type(
index_ops: Union["Series", "Index"], dtype: CategoricalDtype, spark_type: DataType
) -> Union["Index", "Series"]:
"""Cast `index_ops` to categorical dtype, given `dtype` and `spark_type`."""
assert isinstance(dtype, CategoricalDtype)
if dtype.categories is None:
codes, uniques = index_ops.factorize()
return codes._with_new_scol(
codes.spark.column,
field=codes._internal.data_fields[0].copy(dtype=CategoricalDtype(categories=uniques)),
)
else:
categories = dtype.categories
if len(categories) == 0:
scol = F.lit(-1)
else:
kvs = chain(
*[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
)
map_scol = F.create_map(*kvs)
scol = F.coalesce(map_scol.getItem(index_ops.spark.column), F.lit(-1))
return index_ops._with_new_scol(
scol.cast(spark_type).alias(index_ops._internal.data_fields[0].name),
field=index_ops._internal.data_fields[0].copy(
dtype=dtype, spark_type=spark_type, nullable=False
),
)
def _as_bool_type(
index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
"""Cast `index_ops` to BooleanType Spark type, given `dtype`."""
from pyspark.pandas.internal import InternalField
if isinstance(dtype, extension_dtypes):
scol = index_ops.spark.column.cast(BooleanType())
else:
scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
index_ops.spark.column.cast(BooleanType())
)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
def _as_string_type(
index_ops: Union["Series", "Index"],
dtype: Union[str, type, Dtype],
*,
null_str: str = str(None)
) -> Union["Index", "Series"]:
"""Cast `index_ops` to StringType Spark type, given `dtype` and `null_str`,
representing null Spark column.
"""
from pyspark.pandas.internal import InternalField
if isinstance(dtype, extension_dtypes):
scol = index_ops.spark.column.cast(StringType())
else:
casted = index_ops.spark.column.cast(StringType())
scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
def _as_other_type(
index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype], spark_type: DataType
) -> Union["Index", "Series"]:
"""Cast `index_ops` to a `dtype` (`spark_type`) that needs no pre-processing.
Destination types that need pre-processing: CategoricalDtype, BooleanType, and StringType.
"""
from pyspark.pandas.internal import InternalField
need_pre_process = (
isinstance(dtype, CategoricalDtype)
or isinstance(spark_type, BooleanType)
or isinstance(spark_type, StringType)
)
assert not need_pre_process, "Pre-processing is needed before the type casting."
scol = index_ops.spark.column.cast(spark_type)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
class DataTypeOps(object, metaclass=ABCMeta):
"""The base class for binary operations of pandas-on-Spark objects (of different data types)."""
@@ -206,3 +299,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.replace({np.nan: None})
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
raise TypeError("astype can not be applied to %s." % self.pretty_name)

View file: python/pyspark/pandas/data_type_ops/binary_ops.py

@@ -17,10 +17,19 @@
from typing import TYPE_CHECKING, Union
from pyspark.sql import functions as F
from pyspark.sql.types import BinaryType
from pandas.api.types import CategoricalDtype
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_bool_type,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
from pyspark.sql import functions as F
from pyspark.sql.types import BinaryType, BooleanType, StringType
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -53,3 +62,17 @@ class BinaryOps(DataTypeOps):
raise TypeError(
"Concatenation can not be applied to %s and the given type." % self.pretty_name
)
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype)
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/data_type_ops/boolean_ops.py

@@ -19,6 +19,7 @@ import numbers
from typing import TYPE_CHECKING, Union
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark import sql as spark
from pyspark.pandas.base import column_op, IndexOpsMixin
@@ -26,11 +27,15 @@ from pyspark.pandas.data_type_ops.base import (
is_valid_operand_for_numeric_arithmetic,
DataTypeOps,
transform_boolean_operand_to_numeric,
_as_bool_type,
_as_categorical_type,
_as_other_type,
)
from pyspark.pandas.typedef import extension_dtypes
from pyspark.pandas.internal import InternalField
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
from pyspark.pandas.typedef.typehints import as_spark_type
from pyspark.sql import functions as F
from pyspark.sql.types import BooleanType
from pyspark.sql.types import BooleanType, StringType
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -245,6 +250,32 @@ class BooleanOps(DataTypeOps):
return column_op(or_func)(left, right)
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
if isinstance(dtype, extension_dtypes):
scol = F.when(
index_ops.spark.column.isNotNull(),
F.when(index_ops.spark.column, "True").otherwise("False"),
)
else:
null_str = str(None)
casted = F.when(index_ops.spark.column, "True").otherwise("False")
scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
else:
return _as_other_type(index_ops, dtype, spark_type)
class BooleanExtensionOps(BooleanOps):
"""

View file: python/pyspark/pandas/data_type_ops/categorical_ops.py

@@ -15,9 +15,20 @@
# limitations under the License.
#
import pandas as pd
from itertools import chain
from typing import cast, TYPE_CHECKING, Union
import pandas as pd
from pandas.api.types import CategoricalDtype
import pyspark.pandas as ps
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
from pyspark.sql import functions as F
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
class CategoricalOps(DataTypeOps):
@@ -38,3 +49,24 @@ class CategoricalOps(DataTypeOps):
def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.cat.codes
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
return cast(Union[ps.Index, ps.Series], index_ops).copy()
categories = index_ops.dtype.categories
if len(categories) == 0:
scol = F.lit(None)
else:
kvs = chain(
*[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
)
map_scol = F.create_map(*kvs)
scol = map_scol.getItem(index_ops.spark.column)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0])
).astype(dtype)

View file: python/pyspark/pandas/data_type_ops/complex_ops.py

@@ -17,10 +17,19 @@
from typing import TYPE_CHECKING, Union
from pandas.api.types import CategoricalDtype
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_bool_type,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, NumericType
from pyspark.sql.types import ArrayType, BooleanType, NumericType, StringType
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -56,6 +65,20 @@ class ArrayOps(DataTypeOps):
return column_op(F.concat)(left, right)
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype)
else:
return _as_other_type(index_ops, dtype, spark_type)
class MapOps(DataTypeOps):
"""

View file: python/pyspark/pandas/data_type_ops/date_ops.py

@@ -19,11 +19,21 @@ import datetime
import warnings
from typing import TYPE_CHECKING, Union
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark.sql import functions as F
from pyspark.sql.types import DateType
from pyspark.sql.types import BooleanType, DateType, StringType
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_bool_type,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -69,3 +79,17 @@ class DateOps(DataTypeOps):
return -column_op(F.datediff)(left, F.lit(right)).astype("long")
else:
raise TypeError("date subtraction can only be applied to date series.")
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype, null_str=str(pd.NaT))
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/data_type_ops/datetime_ops.py

@@ -19,12 +19,21 @@ import datetime
import warnings
from typing import TYPE_CHECKING, Union
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType
from pyspark.sql.types import BooleanType, StringType, TimestampType
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.typedef import as_spark_type
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_bool_type,
_as_categorical_type,
_as_other_type,
)
from pyspark.pandas.internal import InternalField
from pyspark.pandas.typedef import as_spark_type, Dtype, extension_dtypes, pandas_on_spark_type
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -78,3 +87,29 @@ class DatetimeOps(DataTypeOps):
def prepare(self, col):
"""Prepare column when from_pandas."""
return col
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
if isinstance(dtype, extension_dtypes):
# seems like a pandas' bug?
scol = F.when(index_ops.spark.column.isNull(), str(pd.NaT)).otherwise(
index_ops.spark.column.cast(spark_type)
)
else:
null_str = str(pd.NaT)
casted = index_ops.spark.column.cast(spark_type)
scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/data_type_ops/null_ops.py

@@ -15,7 +15,23 @@
# limitations under the License.
#
from pyspark.pandas.data_type_ops.base import DataTypeOps
from typing import TYPE_CHECKING, Union
from pandas.api.types import CategoricalDtype
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_bool_type,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
from pyspark.sql.types import BooleanType, StringType
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
class NullOps(DataTypeOps):
@@ -26,3 +42,17 @@ class NullOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return "nulls"
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype)
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/data_type_ops/num_ops.py

@@ -19,22 +19,30 @@ import numbers
from typing import TYPE_CHECKING, Union
import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.types import (
StringType,
TimestampType,
)
from pandas.api.types import CategoricalDtype
from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op
from pyspark.pandas.data_type_ops.base import (
is_valid_operand_for_numeric_arithmetic,
DataTypeOps,
transform_boolean_operand_to_numeric,
_as_bool_type,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.internal import InternalField
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
from pyspark.sql import functions as F
from pyspark.sql.column import Column
from pyspark.sql.types import (
BooleanType,
DoubleType,
FloatType,
StringType,
TimestampType,
)
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -248,6 +256,20 @@ class IntegralOps(NumericOps):
right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
return numpy_column_op(rfloordiv)(left, right)
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
return _as_bool_type(index_ops, dtype)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype, null_str=str(np.nan))
else:
return _as_other_type(index_ops, dtype, spark_type)
class FractionalOps(NumericOps):
"""
@@ -344,3 +366,32 @@ class FractionalOps(NumericOps):
right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
return numpy_column_op(rfloordiv)(left, right)
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
elif isinstance(spark_type, BooleanType):
if isinstance(dtype, extension_dtypes):
scol = index_ops.spark.column.cast(spark_type)
else:
if isinstance(index_ops.spark.data_type, (FloatType, DoubleType)):
scol = F.when(
index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column),
F.lit(True),
).otherwise(index_ops.spark.column.cast(spark_type))
else: # DecimalType
scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
index_ops.spark.column.cast(spark_type)
)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype, null_str=str(np.nan))
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/data_type_ops/string_ops.py

@@ -23,8 +23,16 @@ from pyspark.sql import functions as F
from pyspark.sql.types import IntegralType, StringType
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
_as_categorical_type,
_as_other_type,
_as_string_type,
)
from pyspark.pandas.internal import InternalField
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
from pyspark.sql.types import BooleanType
if TYPE_CHECKING:
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
@@ -102,3 +110,27 @@ class StringOps(DataTypeOps):
def rmod(self, left, right):
raise TypeError("modulo can not be applied on string series or literals.")
def astype(
self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
dtype, spark_type = pandas_on_spark_type(dtype)
if isinstance(dtype, CategoricalDtype):
return _as_categorical_type(index_ops, dtype, spark_type)
if isinstance(spark_type, BooleanType):
if isinstance(dtype, extension_dtypes):
scol = index_ops.spark.column.cast(spark_type)
else:
scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
F.length(index_ops.spark.column) > 0
)
return index_ops._with_new_scol(
scol.alias(index_ops._internal.data_spark_column_names[0]),
field=InternalField(dtype=dtype),
)
elif isinstance(spark_type, StringType):
return _as_string_type(index_ops, dtype)
else:
return _as_other_type(index_ops, dtype, spark_type)

View file: python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py

@@ -16,6 +16,7 @@
#
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -147,6 +148,14 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(pd.Series(["1", "2", "3"]), psser.astype(str))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[b"2", b"3", b"1"])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py

@@ -21,6 +21,7 @@ from distutils.version import LooseVersion
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -287,6 +288,21 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(True | pser, True | psser)
self.assert_eq(False | pser, False | psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(pser.astype(int), psser.astype(int))
self.assert_eq(pser.astype(float), psser.astype(float))
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype(bool), psser.astype(bool))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[False, True])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
@unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@@ -578,6 +594,14 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[False, True])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
from pyspark.pandas.tests.data_type_ops.test_boolean_ops import * # noqa: F401

View file: python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py

@@ -15,7 +15,11 @@
# limitations under the License.
#
from distutils.version import LooseVersion
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -140,6 +144,25 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
data = [1, 2, 3]
pser = pd.Series(data, dtype="category")
psser = ps.from_pandas(pser)
self.assert_eq(pser.astype(int), psser.astype(int))
self.assert_eq(pser.astype(float), psser.astype(float))
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype(bool), psser.astype(bool))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[3, 1, 2])
if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
else:
self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py

@@ -65,9 +65,13 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
def pssers(self):
return self.numeric_array_pssers + list(self.non_numeric_array_pssers.values())
@property
def pser(self):
return pd.Series([[1, 2, 3]])
@property
def psser(self):
return ps.Series([[1, 2, 3]])
return ps.from_pandas(self.pser)
def test_add(self):
for pser, psser in zip(self.psers, self.pssers):
@@ -213,6 +217,9 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
self.assert_eq(self.pser.astype(str), self.psser.astype(str))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_date_ops.py

@@ -18,6 +18,7 @@
import datetime
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark.sql.types import DateType
@@ -172,6 +173,14 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pd.Series([None, None, None]), psser.astype(bool))
cat_type = CategoricalDtype(categories=["a", "b", "c"])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py

@@ -19,6 +19,7 @@ import datetime
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -172,6 +173,14 @@ class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=["a", "b", "c"])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_null_ops.py

@@ -16,6 +16,7 @@
#
import pandas as pd
from pandas.api.types import CategoricalDtype
import pyspark.pandas as ps
from pyspark.pandas.config import option_context
@@ -122,6 +123,15 @@ class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = self.pser
psser = self.psser
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype(bool), psser.astype(bool))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[1, 2, 3])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_num_ops.py

@@ -20,6 +20,7 @@ from distutils.version import LooseVersion
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -293,6 +294,20 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
for pser, psser in self.numeric_pser_psser_pairs:
self.assert_eq(pser.astype(int), psser.astype(int))
self.assert_eq(pser.astype(float), psser.astype(float))
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype(bool), psser.astype(bool))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=[2, 1, 3])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest

View file: python/pyspark/pandas/tests/data_type_ops/test_string_ops.py

@@ -17,6 +17,7 @@
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
from pyspark import pandas as ps
from pyspark.pandas.config import option_context
@@ -154,10 +155,25 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
pser = pd.Series(["1", "2", "3"])
psser = ps.from_pandas(pser)
self.assert_eq(pser.astype(int), psser.astype(int))
self.assert_eq(pser.astype(float), psser.astype(float))
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
self.assert_eq(pser.astype(str), psser.astype(str))
self.assert_eq(pser.astype(bool), psser.astype(bool))
self.assert_eq(pser.astype("category"), psser.astype("category"))
cat_type = CategoricalDtype(categories=["3", "1", "2"])
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
if __name__ == "__main__":
import unittest
from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401
from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401
try:
import xmlrunner # type: ignore[import]

View file: python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py

@@ -125,6 +125,9 @@ class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils):
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)
def test_astype(self):
self.assertRaises(TypeError, lambda: self.psser.astype(str))
if __name__ == "__main__":
import unittest