[SPARK-35616][PYTHON] Make astype
method data-type-based
### What changes were proposed in this pull request? Make the `astype` method data-type-based. **Non-goal: match pandas' `astype` TypeErrors.** Currently, `astype` raises TypeError messages only when the destination type is not recognized. However, for some destination types that don't make sense for the specific type of Series/Index, for example, `numeric Series/Index → bytes`, we don't have proper TypeError messages. Since the goal of this PR is mainly refactoring, the above issue might be resolved later if needed. ### Why are the changes needed? There are many type checks in the `astype` method. Since `DataTypeOps` and its subclasses are introduced, we should refactor `astype` to make it data-type-based. In this way, the code is cleaner, more maintainable, and more flexible. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. Closes #32847 from xinrong-databricks/datatypeops_astype. Authored-by: Xinrong Meng <xinrong.meng@databricks.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
aab0c2bf66
commit
03756618fc
|
@ -29,14 +29,9 @@ from pandas.api.types import is_list_like, CategoricalDtype
|
|||
from pyspark import sql as spark
|
||||
from pyspark.sql import functions as F, Window, Column
|
||||
from pyspark.sql.types import (
|
||||
BooleanType,
|
||||
DateType,
|
||||
DoubleType,
|
||||
FloatType,
|
||||
LongType,
|
||||
NumericType,
|
||||
StringType,
|
||||
TimestampType,
|
||||
)
|
||||
|
||||
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
|
||||
|
@ -52,7 +47,6 @@ from pyspark.pandas.spark.accessors import SparkIndexOpsMethods
|
|||
from pyspark.pandas.typedef import (
|
||||
Dtype,
|
||||
extension_dtypes,
|
||||
pandas_on_spark_type,
|
||||
)
|
||||
from pyspark.pandas.utils import (
|
||||
combine_frames,
|
||||
|
@ -802,103 +796,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
|
|||
>>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
|
||||
Int64Index([1, 2], dtype='int64', name='a')
|
||||
"""
|
||||
dtype, spark_type = pandas_on_spark_type(dtype)
|
||||
if not spark_type:
|
||||
raise ValueError("Type {} not understood".format(dtype))
|
||||
|
||||
if isinstance(self.dtype, CategoricalDtype):
|
||||
if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
|
||||
return cast(Union[ps.Index, ps.Series], self).copy()
|
||||
|
||||
categories = self.dtype.categories
|
||||
if len(categories) == 0:
|
||||
scol = F.lit(None)
|
||||
else:
|
||||
kvs = chain(
|
||||
*[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
|
||||
)
|
||||
map_scol = F.create_map(*kvs)
|
||||
scol = map_scol.getItem(self.spark.column)
|
||||
return self._with_new_scol(
|
||||
scol.alias(self._internal.data_spark_column_names[0])
|
||||
).astype(dtype)
|
||||
elif isinstance(dtype, CategoricalDtype):
|
||||
if dtype.categories is None:
|
||||
codes, uniques = self.factorize()
|
||||
return codes._with_new_scol(
|
||||
codes.spark.column,
|
||||
field=codes._internal.data_fields[0].copy(
|
||||
dtype=CategoricalDtype(categories=uniques)
|
||||
),
|
||||
)
|
||||
else:
|
||||
categories = dtype.categories
|
||||
if len(categories) == 0:
|
||||
scol = F.lit(-1)
|
||||
else:
|
||||
kvs = chain(
|
||||
*[
|
||||
(F.lit(category), F.lit(code))
|
||||
for code, category in enumerate(categories)
|
||||
]
|
||||
)
|
||||
map_scol = F.create_map(*kvs)
|
||||
|
||||
scol = F.coalesce(map_scol.getItem(self.spark.column), F.lit(-1))
|
||||
return self._with_new_scol(
|
||||
scol.cast(spark_type).alias(self._internal.data_fields[0].name),
|
||||
field=self._internal.data_fields[0].copy(
|
||||
dtype=dtype, spark_type=spark_type, nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(spark_type, BooleanType):
|
||||
if isinstance(dtype, extension_dtypes):
|
||||
scol = self.spark.column.cast(spark_type)
|
||||
else:
|
||||
if isinstance(self.spark.data_type, StringType):
|
||||
scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
|
||||
F.length(self.spark.column) > 0
|
||||
)
|
||||
elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
|
||||
scol = F.when(
|
||||
self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
|
||||
).otherwise(self.spark.column.cast(spark_type))
|
||||
else:
|
||||
scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
|
||||
self.spark.column.cast(spark_type)
|
||||
)
|
||||
elif isinstance(spark_type, StringType):
|
||||
if isinstance(dtype, extension_dtypes):
|
||||
if isinstance(self.spark.data_type, BooleanType):
|
||||
scol = F.when(
|
||||
self.spark.column.isNotNull(),
|
||||
F.when(self.spark.column, "True").otherwise("False"),
|
||||
)
|
||||
elif isinstance(self.spark.data_type, TimestampType):
|
||||
# seems like a pandas' bug?
|
||||
scol = F.when(self.spark.column.isNull(), str(pd.NaT)).otherwise(
|
||||
self.spark.column.cast(spark_type)
|
||||
)
|
||||
else:
|
||||
scol = self.spark.column.cast(spark_type)
|
||||
else:
|
||||
if isinstance(self.spark.data_type, NumericType):
|
||||
null_str = str(np.nan)
|
||||
elif isinstance(self.spark.data_type, (DateType, TimestampType)):
|
||||
null_str = str(pd.NaT)
|
||||
else:
|
||||
null_str = str(None)
|
||||
if isinstance(self.spark.data_type, BooleanType):
|
||||
casted = F.when(self.spark.column, "True").otherwise("False")
|
||||
else:
|
||||
casted = self.spark.column.cast(spark_type)
|
||||
scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
|
||||
else:
|
||||
scol = self.spark.column.cast(spark_type)
|
||||
return self._with_new_scol(
|
||||
scol.alias(self._internal.data_spark_column_names[0]), field=InternalField(dtype=dtype)
|
||||
)
|
||||
return self._dtype_op.astype(self, dtype)
|
||||
|
||||
def isin(self, values) -> Union["Series", "Index"]:
|
||||
"""
|
||||
|
|
|
@ -17,12 +17,14 @@
|
|||
|
||||
import numbers
|
||||
from abc import ABCMeta
|
||||
from itertools import chain
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import (
|
||||
ArrayType,
|
||||
BinaryType,
|
||||
|
@ -39,9 +41,7 @@ from pyspark.sql.types import (
|
|||
TimestampType,
|
||||
UserDefinedType,
|
||||
)
|
||||
|
||||
import pyspark.sql.types as types
|
||||
from pyspark.pandas.typedef import Dtype
|
||||
from pyspark.pandas.typedef import Dtype, extension_dtypes
|
||||
from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
|
||||
|
||||
if extension_object_dtypes_available:
|
||||
|
@ -70,7 +70,7 @@ def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool =
|
|||
|
||||
|
||||
def transform_boolean_operand_to_numeric(
|
||||
operand: Any, spark_type: Optional[types.DataType] = None
|
||||
operand: Any, spark_type: Optional[DataType] = None
|
||||
) -> Any:
|
||||
"""Transform boolean operand to numeric.
|
||||
|
||||
|
@ -90,6 +90,99 @@ def transform_boolean_operand_to_numeric(
|
|||
return operand
|
||||
|
||||
|
||||
def _as_categorical_type(
    index_ops: Union["Series", "Index"], dtype: CategoricalDtype, spark_type: DataType
) -> Union["Index", "Series"]:
    """Cast `index_ops` to categorical dtype, given `dtype` and `spark_type`."""
    assert isinstance(dtype, CategoricalDtype)
    if dtype.categories is None:
        # No explicit categories given: derive them by factorizing the data.
        codes, uniques = index_ops.factorize()
        new_field = codes._internal.data_fields[0].copy(
            dtype=CategoricalDtype(categories=uniques)
        )
        return codes._with_new_scol(codes.spark.column, field=new_field)

    categories = dtype.categories
    if len(categories) == 0:
        # Empty category set: every value is an unknown category (-1).
        code_col = F.lit(-1)
    else:
        # Map each category value to its positional code; unseen values -> -1.
        pairs = chain.from_iterable(
            (F.lit(category), F.lit(code)) for code, category in enumerate(categories)
        )
        lookup = F.create_map(*pairs)
        code_col = F.coalesce(lookup.getItem(index_ops.spark.column), F.lit(-1))
    return index_ops._with_new_scol(
        code_col.cast(spark_type).alias(index_ops._internal.data_fields[0].name),
        field=index_ops._internal.data_fields[0].copy(
            dtype=dtype, spark_type=spark_type, nullable=False
        ),
    )
|
||||
|
||||
|
||||
def _as_bool_type(
    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast `index_ops` to BooleanType Spark type, given `dtype`."""
    from pyspark.pandas.internal import InternalField

    col = index_ops.spark.column
    casted = col.cast(BooleanType())
    if isinstance(dtype, extension_dtypes):
        # Nullable extension dtype: a plain cast keeps nulls as nulls.
        scol = casted
    else:
        # Plain bool has no missing-value slot, so nulls become False.
        scol = F.when(col.isNull(), F.lit(False)).otherwise(casted)
    return index_ops._with_new_scol(
        scol.alias(index_ops._internal.data_spark_column_names[0]),
        field=InternalField(dtype=dtype),
    )
|
||||
|
||||
|
||||
def _as_string_type(
    index_ops: Union["Series", "Index"],
    dtype: Union[str, type, Dtype],
    *,
    null_str: str = str(None)
) -> Union["Index", "Series"]:
    """Cast `index_ops` to StringType Spark type, given `dtype` and `null_str`,
    representing null Spark column.
    """
    from pyspark.pandas.internal import InternalField

    col = index_ops.spark.column
    casted = col.cast(StringType())
    if isinstance(dtype, extension_dtypes):
        # Nullable extension dtype: keep nulls as nulls.
        scol = casted
    else:
        # Plain str: render missing values as the textual placeholder.
        scol = F.when(col.isNull(), null_str).otherwise(casted)
    return index_ops._with_new_scol(
        scol.alias(index_ops._internal.data_spark_column_names[0]),
        field=InternalField(dtype=dtype),
    )
|
||||
|
||||
|
||||
def _as_other_type(
    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype], spark_type: DataType
) -> Union["Index", "Series"]:
    """Cast `index_ops` to a `dtype` (`spark_type`) that needs no pre-processing.

    Destination types that need pre-processing: CategoricalDtype, BooleanType, and StringType.
    """
    from pyspark.pandas.internal import InternalField

    assert not (
        isinstance(dtype, CategoricalDtype)
        or isinstance(spark_type, (BooleanType, StringType))
    ), "Pre-processing is needed before the type casting."

    casted = index_ops.spark.column.cast(spark_type)
    return index_ops._with_new_scol(
        casted.alias(index_ops._internal.data_spark_column_names[0]),
        field=InternalField(dtype=dtype),
    )
|
||||
|
||||
|
||||
class DataTypeOps(object, metaclass=ABCMeta):
|
||||
"""The base class for binary operations of pandas-on-Spark objects (of different data types)."""
|
||||
|
||||
|
@ -206,3 +299,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
|
|||
def prepare(self, col: pd.Series) -> pd.Series:
    """Prepare column when from_pandas: normalize NaN values to None."""
    nan_to_none = {np.nan: None}
    return col.replace(nan_to_none)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Base implementation: this data type does not support casting."""
    # Subclasses that support casting override this method.
    raise TypeError("astype can not be applied to %s." % self.pretty_name)
|
||||
|
|
|
@ -17,10 +17,19 @@
|
|||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import BinaryType
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import BinaryType, BooleanType, StringType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -53,3 +62,17 @@ class BinaryOps(DataTypeOps):
|
|||
raise TypeError(
|
||||
"Concatenation can not be applied to %s and the given type." % self.pretty_name
|
||||
)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a binary Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        # Categorical targets need code/category mapping, not a plain cast.
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        return _as_string_type(index_ops, dtype)
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -19,6 +19,7 @@ import numbers
|
|||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import sql as spark
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
|
@ -26,11 +27,15 @@ from pyspark.pandas.data_type_ops.base import (
|
|||
is_valid_operand_for_numeric_arithmetic,
|
||||
DataTypeOps,
|
||||
transform_boolean_operand_to_numeric,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
)
|
||||
from pyspark.pandas.typedef import extension_dtypes
|
||||
from pyspark.pandas.internal import InternalField
|
||||
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
|
||||
from pyspark.pandas.typedef.typehints import as_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import BooleanType
|
||||
from pyspark.sql.types import BooleanType, StringType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -245,6 +250,32 @@ class BooleanOps(DataTypeOps):
|
|||
|
||||
return column_op(or_func)(left, right)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a boolean Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        col = index_ops.spark.column
        # Booleans stringify with pandas-style capitalization: "True"/"False".
        as_text = F.when(col, "True").otherwise("False")
        if isinstance(dtype, extension_dtypes):
            # Nullable string dtype: nulls stay null.
            scol = F.when(col.isNotNull(), as_text)
        else:
            # Plain str: nulls become the literal text str(None).
            scol = F.when(col.isNull(), str(None)).otherwise(as_text)
        return index_ops._with_new_scol(
            scol.alias(index_ops._internal.data_spark_column_names[0]),
            field=InternalField(dtype=dtype),
        )
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
||||
|
||||
class BooleanExtensionOps(BooleanOps):
|
||||
"""
|
||||
|
|
|
@ -15,9 +15,20 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pandas as pd
|
||||
from itertools import chain
|
||||
from typing import cast, TYPE_CHECKING, Union
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
import pyspark.pandas as ps
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
|
||||
|
||||
|
||||
class CategoricalOps(DataTypeOps):
|
||||
|
@ -38,3 +49,24 @@ class CategoricalOps(DataTypeOps):
|
|||
def prepare(self, col: pd.Series) -> pd.Series:
    """Prepare column when from_pandas: store categoricals as their integer codes."""
    codes = col.cat.codes
    return codes
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a categorical Series/Index to `dtype` by first mapping codes back to values."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
        # Target is an unspecified categorical: the data is already categorical.
        return cast(Union[ps.Index, ps.Series], index_ops).copy()

    categories = index_ops.dtype.categories
    if len(categories) == 0:
        value_col = F.lit(None)
    else:
        # Translate the stored integer codes back into their category values.
        pairs = chain.from_iterable(
            (F.lit(code), F.lit(category)) for code, category in enumerate(categories)
        )
        value_col = F.create_map(*pairs).getItem(index_ops.spark.column)
    # Re-cast the reconstructed values to the requested dtype.
    return index_ops._with_new_scol(
        value_col.alias(index_ops._internal.data_spark_column_names[0])
    ).astype(dtype)
|
||||
|
|
|
@ -17,10 +17,19 @@
|
|||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import ArrayType, NumericType
|
||||
from pyspark.sql.types import ArrayType, BooleanType, NumericType, StringType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -56,6 +65,20 @@ class ArrayOps(DataTypeOps):
|
|||
|
||||
return column_op(F.concat)(left, right)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast an array Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        return _as_string_type(index_ops, dtype)
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
||||
|
||||
class MapOps(DataTypeOps):
|
||||
"""
|
||||
|
|
|
@ -19,11 +19,21 @@ import datetime
|
|||
import warnings
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import DateType
|
||||
from pyspark.sql.types import BooleanType, DateType, StringType
|
||||
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -69,3 +79,17 @@ class DateOps(DataTypeOps):
|
|||
return -column_op(F.datediff)(left, F.lit(right)).astype("long")
|
||||
else:
|
||||
raise TypeError("date subtraction can only be applied to date series.")
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a date Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        # Missing dates stringify as str(pd.NaT).
        return _as_string_type(index_ops, dtype, null_str=str(pd.NaT))
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -19,12 +19,21 @@ import datetime
|
|||
import warnings
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import TimestampType
|
||||
from pyspark.sql.types import BooleanType, StringType, TimestampType
|
||||
|
||||
from pyspark.pandas.base import IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.typedef import as_spark_type
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
)
|
||||
from pyspark.pandas.internal import InternalField
|
||||
from pyspark.pandas.typedef import as_spark_type, Dtype, extension_dtypes, pandas_on_spark_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -78,3 +87,29 @@ class DatetimeOps(DataTypeOps):
|
|||
def prepare(self, col):
    """Prepare column when from_pandas."""
    # Datetime data needs no normalization; return the column unchanged.
    return col
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a datetime Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        col = index_ops.spark.column
        # Both plain str and extension string dtypes render missing values as
        # str(pd.NaT) here; the original marked the extension case with
        # "seems like a pandas' bug?" — the two branches were identical.
        scol = F.when(col.isNull(), str(pd.NaT)).otherwise(col.cast(spark_type))
        return index_ops._with_new_scol(
            scol.alias(index_ops._internal.data_spark_column_names[0]),
            field=InternalField(dtype=dtype),
        )
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -15,7 +15,23 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
|
||||
from pyspark.sql.types import BooleanType, StringType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
|
||||
|
||||
|
||||
class NullOps(DataTypeOps):
|
||||
|
@ -26,3 +42,17 @@ class NullOps(DataTypeOps):
|
|||
@property
def pretty_name(self) -> str:
    """Human-readable type name used in error messages."""
    return "nulls"
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a null Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        return _as_string_type(index_ops, dtype)
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -19,22 +19,30 @@ import numbers
|
|||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import (
|
||||
StringType,
|
||||
TimestampType,
|
||||
)
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
is_valid_operand_for_numeric_arithmetic,
|
||||
DataTypeOps,
|
||||
transform_boolean_operand_to_numeric,
|
||||
_as_bool_type,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.internal import InternalField
|
||||
from pyspark.pandas.spark import functions as SF
|
||||
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.column import Column
|
||||
|
||||
from pyspark.sql.types import (
|
||||
BooleanType,
|
||||
DoubleType,
|
||||
FloatType,
|
||||
StringType,
|
||||
TimestampType,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -248,6 +256,20 @@ class IntegralOps(NumericOps):
|
|||
right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
|
||||
return numpy_column_op(rfloordiv)(left, right)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast an integral Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        return _as_bool_type(index_ops, dtype)
    if isinstance(spark_type, StringType):
        # Missing numeric values stringify as str(np.nan).
        return _as_string_type(index_ops, dtype, null_str=str(np.nan))
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
||||
|
||||
class FractionalOps(NumericOps):
|
||||
"""
|
||||
|
@ -344,3 +366,32 @@ class FractionalOps(NumericOps):
|
|||
|
||||
right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
|
||||
return numpy_column_op(rfloordiv)(left, right)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a fractional (float/double/decimal) Series/Index to `dtype`."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        col = index_ops.spark.column
        casted = col.cast(spark_type)
        if isinstance(dtype, extension_dtypes):
            # Nullable boolean dtype: a plain cast keeps nulls as nulls.
            scol = casted
        elif isinstance(index_ops.spark.data_type, (FloatType, DoubleType)):
            # Plain bool from float/double: both null and NaN become True.
            scol = F.when(col.isNull() | F.isnan(col), F.lit(True)).otherwise(casted)
        else:  # DecimalType
            scol = F.when(col.isNull(), F.lit(False)).otherwise(casted)
        return index_ops._with_new_scol(
            scol.alias(index_ops._internal.data_spark_column_names[0]),
            field=InternalField(dtype=dtype),
        )
    if isinstance(spark_type, StringType):
        # Missing numeric values stringify as str(np.nan).
        return _as_string_type(index_ops, dtype, null_str=str(np.nan))
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -23,8 +23,16 @@ from pyspark.sql import functions as F
|
|||
from pyspark.sql.types import IntegralType, StringType
|
||||
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.data_type_ops.base import (
|
||||
DataTypeOps,
|
||||
_as_categorical_type,
|
||||
_as_other_type,
|
||||
_as_string_type,
|
||||
)
|
||||
from pyspark.pandas.internal import InternalField
|
||||
from pyspark.pandas.spark import functions as SF
|
||||
from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
|
||||
from pyspark.sql.types import BooleanType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyspark.pandas.indexes import Index # noqa: F401 (SPARK-34943)
|
||||
|
@ -102,3 +110,27 @@ class StringOps(DataTypeOps):
|
|||
|
||||
def rmod(self, left, right):
    """Reflected modulo is unsupported for string data."""
    message = "modulo can not be applied on string series or literals."
    raise TypeError(message)
|
||||
|
||||
def astype(
    self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
) -> Union["Index", "Series"]:
    """Cast a string Series/Index to `dtype`, dispatching on the target type."""
    dtype, spark_type = pandas_on_spark_type(dtype)

    if isinstance(dtype, CategoricalDtype):
        return _as_categorical_type(index_ops, dtype, spark_type)
    if isinstance(spark_type, BooleanType):
        col = index_ops.spark.column
        if isinstance(dtype, extension_dtypes):
            scol = col.cast(spark_type)
        else:
            # Plain bool follows string truthiness: non-empty -> True, null -> False.
            scol = F.when(col.isNull(), F.lit(False)).otherwise(F.length(col) > 0)
        return index_ops._with_new_scol(
            scol.alias(index_ops._internal.data_spark_column_names[0]),
            field=InternalField(dtype=dtype),
        )
    if isinstance(spark_type, StringType):
        return _as_string_type(index_ops, dtype)
    return _as_other_type(index_ops, dtype, spark_type)
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -147,6 +148,14 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
    pser, psser = self.pser, self.psser
    self.assert_eq(pd.Series(["1", "2", "3"]), psser.astype(str))
    self.assert_eq(pser.astype("category"), psser.astype("category"))
    explicit_cat = CategoricalDtype(categories=[b"2", b"3", b"1"])
    self.assert_eq(pser.astype(explicit_cat), psser.astype(explicit_cat))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -21,6 +21,7 @@ from distutils.version import LooseVersion
|
|||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -287,6 +288,21 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(True | pser, True | psser)
|
||||
self.assert_eq(False | pser, False | psser)
|
||||
|
||||
def test_astype(self):
    pser, psser = self.pser, self.psser
    # Every plain target dtype should round-trip identically to pandas.
    targets = [int, float, np.float32, np.int32, np.int16, np.int8, str, bool, "category"]
    for target in targets:
        self.assert_eq(pser.astype(target), psser.astype(target))
    explicit_cat = CategoricalDtype(categories=[False, True])
    self.assert_eq(pser.astype(explicit_cat), psser.astype(explicit_cat))
|
||||
|
||||
|
||||
@unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
|
||||
class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
||||
|
@ -578,6 +594,14 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
pser = self.pser
|
||||
psser = self.psser
|
||||
self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=[False, True])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pyspark.pandas.tests.data_type_ops.test_boolean_ops import * # noqa: F401
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -140,6 +144,25 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
data = [1, 2, 3]
|
||||
pser = pd.Series(data, dtype="category")
|
||||
psser = ps.from_pandas(pser)
|
||||
self.assert_eq(pser.astype(int), psser.astype(int))
|
||||
self.assert_eq(pser.astype(float), psser.astype(float))
|
||||
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
|
||||
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
|
||||
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
|
||||
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pser.astype(bool), psser.astype(bool))
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=[3, 1, 2])
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
else:
|
||||
self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -65,9 +65,13 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
def pssers(self):
|
||||
return self.numeric_array_pssers + list(self.non_numeric_array_pssers.values())
|
||||
|
||||
@property
|
||||
def pser(self):
|
||||
return pd.Series([[1, 2, 3]])
|
||||
|
||||
@property
|
||||
def psser(self):
|
||||
return ps.Series([[1, 2, 3]])
|
||||
return ps.from_pandas(self.pser)
|
||||
|
||||
def test_add(self):
|
||||
for pser, psser in zip(self.psers, self.pssers):
|
||||
|
@ -213,6 +217,9 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
self.assert_eq(self.pser.astype(str), self.psser.astype(str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
import datetime
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.sql.types import DateType
|
||||
|
||||
|
@ -172,6 +173,14 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
pser = self.pser
|
||||
psser = self.psser
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pd.Series([None, None, None]), psser.astype(bool))
|
||||
cat_type = CategoricalDtype(categories=["a", "b", "c"])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -19,6 +19,7 @@ import datetime
|
|||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -172,6 +173,14 @@ class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
pser = self.pser
|
||||
psser = self.psser
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=["a", "b", "c"])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#
|
||||
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
import pyspark.pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -122,6 +123,15 @@ class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
pser = self.pser
|
||||
psser = self.psser
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pser.astype(bool), psser.astype(bool))
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=[1, 2, 3])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -20,6 +20,7 @@ from distutils.version import LooseVersion
|
|||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -293,6 +294,20 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
for pser, psser in self.numeric_pser_psser_pairs:
|
||||
self.assert_eq(pser.astype(int), psser.astype(int))
|
||||
self.assert_eq(pser.astype(float), psser.astype(float))
|
||||
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
|
||||
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
|
||||
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
|
||||
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pser.astype(bool), psser.astype(bool))
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=[2, 1, 3])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark import pandas as ps
|
||||
from pyspark.pandas.config import option_context
|
||||
|
@ -154,10 +155,25 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
pser = pd.Series(["1", "2", "3"])
|
||||
psser = ps.from_pandas(pser)
|
||||
self.assert_eq(pser.astype(int), psser.astype(int))
|
||||
self.assert_eq(pser.astype(float), psser.astype(float))
|
||||
self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
|
||||
self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
|
||||
self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
|
||||
self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
|
||||
self.assert_eq(pser.astype(str), psser.astype(str))
|
||||
self.assert_eq(pser.astype(bool), psser.astype(bool))
|
||||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
cat_type = CategoricalDtype(categories=["3", "1", "2"])
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
from pyspark.pandas.tests.data_type_ops.test_num_ops import * # noqa: F401
|
||||
from pyspark.pandas.tests.data_type_ops.test_string_ops import * # noqa: F401
|
||||
|
||||
try:
|
||||
import xmlrunner # type: ignore[import]
|
||||
|
|
|
@ -125,6 +125,9 @@ class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser, psser.to_pandas())
|
||||
self.assert_eq(ps.from_pandas(pser), psser)
|
||||
|
||||
def test_astype(self):
|
||||
self.assertRaises(TypeError, lambda: self.psser.astype(str))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
|
Loading…
Reference in a new issue