diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py
index ec1b8b400e..4428391a78 100644
--- a/python/pyspark/pandas/base.py
+++ b/python/pyspark/pandas/base.py
@@ -29,14 +29,9 @@ from pandas.api.types import is_list_like, CategoricalDtype
 from pyspark import sql as spark
 from pyspark.sql import functions as F, Window, Column
 from pyspark.sql.types import (
-    BooleanType,
-    DateType,
     DoubleType,
     FloatType,
     LongType,
-    NumericType,
-    StringType,
-    TimestampType,
 )
 
 from pyspark import pandas as ps  # For running doctests and reference resolution in PyCharm.
@@ -52,7 +47,6 @@ from pyspark.pandas.spark.accessors import SparkIndexOpsMethods
 from pyspark.pandas.typedef import (
     Dtype,
     extension_dtypes,
-    pandas_on_spark_type,
 )
 from pyspark.pandas.utils import (
     combine_frames,
@@ -802,103 +796,7 @@ class IndexOpsMixin(object, metaclass=ABCMeta):
         >>> ser.rename("a").to_frame().set_index("a").index.astype('int64')
         Int64Index([1, 2], dtype='int64', name='a')
         """
-        dtype, spark_type = pandas_on_spark_type(dtype)
-        if not spark_type:
-            raise ValueError("Type {} not understood".format(dtype))
-
-        if isinstance(self.dtype, CategoricalDtype):
-            if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
-                return cast(Union[ps.Index, ps.Series], self).copy()
-
-            categories = self.dtype.categories
-            if len(categories) == 0:
-                scol = F.lit(None)
-            else:
-                kvs = chain(
-                    *[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
-                )
-                map_scol = F.create_map(*kvs)
-                scol = map_scol.getItem(self.spark.column)
-            return self._with_new_scol(
-                scol.alias(self._internal.data_spark_column_names[0])
-            ).astype(dtype)
-        elif isinstance(dtype, CategoricalDtype):
-            if dtype.categories is None:
-                codes, uniques = self.factorize()
-                return codes._with_new_scol(
-                    codes.spark.column,
-                    field=codes._internal.data_fields[0].copy(
-                        dtype=CategoricalDtype(categories=uniques)
-                    ),
-                )
-            else:
-                categories = dtype.categories
-                if len(categories) == 0:
-                    scol = F.lit(-1)
-                else:
-                    kvs = chain(
-                        *[
-                            (F.lit(category), F.lit(code))
-                            for code, category in enumerate(categories)
-                        ]
-                    )
-                    map_scol = F.create_map(*kvs)
-
-                    scol = F.coalesce(map_scol.getItem(self.spark.column), F.lit(-1))
-                return self._with_new_scol(
-                    scol.cast(spark_type).alias(self._internal.data_fields[0].name),
-                    field=self._internal.data_fields[0].copy(
-                        dtype=dtype, spark_type=spark_type, nullable=False
-                    ),
-                )
-
-        if isinstance(spark_type, BooleanType):
-            if isinstance(dtype, extension_dtypes):
-                scol = self.spark.column.cast(spark_type)
-            else:
-                if isinstance(self.spark.data_type, StringType):
-                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                        F.length(self.spark.column) > 0
-                    )
-                elif isinstance(self.spark.data_type, (FloatType, DoubleType)):
-                    scol = F.when(
-                        self.spark.column.isNull() | F.isnan(self.spark.column), F.lit(True)
-                    ).otherwise(self.spark.column.cast(spark_type))
-                else:
-                    scol = F.when(self.spark.column.isNull(), F.lit(False)).otherwise(
-                        self.spark.column.cast(spark_type)
-                    )
-        elif isinstance(spark_type, StringType):
-            if isinstance(dtype, extension_dtypes):
-                if isinstance(self.spark.data_type, BooleanType):
-                    scol = F.when(
-                        self.spark.column.isNotNull(),
-                        F.when(self.spark.column, "True").otherwise("False"),
-                    )
-                elif isinstance(self.spark.data_type, TimestampType):
-                    # seems like a pandas' bug?
-                    scol = F.when(self.spark.column.isNull(), str(pd.NaT)).otherwise(
-                        self.spark.column.cast(spark_type)
-                    )
-                else:
-                    scol = self.spark.column.cast(spark_type)
-            else:
-                if isinstance(self.spark.data_type, NumericType):
-                    null_str = str(np.nan)
-                elif isinstance(self.spark.data_type, (DateType, TimestampType)):
-                    null_str = str(pd.NaT)
-                else:
-                    null_str = str(None)
-                if isinstance(self.spark.data_type, BooleanType):
-                    casted = F.when(self.spark.column, "True").otherwise("False")
-                else:
-                    casted = self.spark.column.cast(spark_type)
-                scol = F.when(self.spark.column.isNull(), null_str).otherwise(casted)
-        else:
-            scol = self.spark.column.cast(spark_type)
-        return self._with_new_scol(
-            scol.alias(self._internal.data_spark_column_names[0]), field=InternalField(dtype=dtype)
-        )
+        return self._dtype_op.astype(self, dtype)
 
     def isin(self, values) -> Union["Series", "Index"]:
         """
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index c08db48419..9084453e30 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -17,12 +17,14 @@
 
 import numbers
 from abc import ABCMeta
+from itertools import chain
 from typing import Any, Optional, TYPE_CHECKING, Union
 
 import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
+from pyspark.sql import functions as F
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -39,9 +41,7 @@ from pyspark.sql.types import (
     TimestampType,
     UserDefinedType,
 )
-
-import pyspark.sql.types as types
-from pyspark.pandas.typedef import Dtype
+from pyspark.pandas.typedef import Dtype, extension_dtypes
 from pyspark.pandas.typedef.typehints import extension_object_dtypes_available
 
 if extension_object_dtypes_available:
@@ -70,7 +70,7 @@ def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool =
 
 
 def transform_boolean_operand_to_numeric(
-    operand: Any, spark_type: Optional[types.DataType] = None
+    operand: Any, spark_type: Optional[DataType] = None
 ) -> Any:
     """Transform boolean operand to numeric.
 
@@ -90,6 +90,99 @@ def transform_boolean_operand_to_numeric(
         return operand
 
 
+def _as_categorical_type(
+    index_ops: Union["Series", "Index"], dtype: CategoricalDtype, spark_type: DataType
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to categorical dtype, given `dtype` and `spark_type`."""
+    assert isinstance(dtype, CategoricalDtype)
+    if dtype.categories is None:
+        codes, uniques = index_ops.factorize()
+        return codes._with_new_scol(
+            codes.spark.column,
+            field=codes._internal.data_fields[0].copy(dtype=CategoricalDtype(categories=uniques)),
+        )
+    else:
+        categories = dtype.categories
+        if len(categories) == 0:
+            scol = F.lit(-1)
+        else:
+            kvs = chain(
+                *[(F.lit(category), F.lit(code)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+
+            scol = F.coalesce(map_scol.getItem(index_ops.spark.column), F.lit(-1))
+        return index_ops._with_new_scol(
+            scol.cast(spark_type).alias(index_ops._internal.data_fields[0].name),
+            field=index_ops._internal.data_fields[0].copy(
+                dtype=dtype, spark_type=spark_type, nullable=False
+            ),
+        )
+
+
+def _as_bool_type(
+    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype]
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to BooleanType Spark type, given `dtype`."""
+    from pyspark.pandas.internal import InternalField
+
+    if isinstance(dtype, extension_dtypes):
+        scol = index_ops.spark.column.cast(BooleanType())
+    else:
+        scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+            index_ops.spark.column.cast(BooleanType())
+        )
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
+def _as_string_type(
+    index_ops: Union["Series", "Index"],
+    dtype: Union[str, type, Dtype],
+    *,
+    null_str: str = str(None)
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to StringType Spark type, given `dtype` and `null_str`,
+    representing null Spark column.
+    """
+    from pyspark.pandas.internal import InternalField
+
+    if isinstance(dtype, extension_dtypes):
+        scol = index_ops.spark.column.cast(StringType())
+    else:
+        casted = index_ops.spark.column.cast(StringType())
+        scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
+def _as_other_type(
+    index_ops: Union["Series", "Index"], dtype: Union[str, type, Dtype], spark_type: DataType
+) -> Union["Index", "Series"]:
+    """Cast `index_ops` to a `dtype` (`spark_type`) that needs no pre-processing.
+
+    Destination types that need pre-processing: CategoricalDtype, BooleanType, and StringType.
+    """
+    from pyspark.pandas.internal import InternalField
+
+    need_pre_process = (
+        isinstance(dtype, CategoricalDtype)
+        or isinstance(spark_type, BooleanType)
+        or isinstance(spark_type, StringType)
+    )
+    assert not need_pre_process, "Pre-processing is needed before the type casting."
+
+    scol = index_ops.spark.column.cast(spark_type)
+    return index_ops._with_new_scol(
+        scol.alias(index_ops._internal.data_spark_column_names[0]),
+        field=InternalField(dtype=dtype),
+    )
+
+
 class DataTypeOps(object, metaclass=ABCMeta):
     """The base class for binary operations of pandas-on-Spark objects (of different data types)."""
 
@@ -206,3 +299,8 @@ class DataTypeOps(object, metaclass=ABCMeta):
     def prepare(self, col: pd.Series) -> pd.Series:
         """Prepare column when from_pandas."""
         return col.replace({np.nan: None})
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        raise TypeError("astype can not be applied to %s." % self.pretty_name)
diff --git a/python/pyspark/pandas/data_type_ops/binary_ops.py b/python/pyspark/pandas/data_type_ops/binary_ops.py
index 71c18bfe25..f5407e1415 100644
--- a/python/pyspark/pandas/data_type_ops/binary_ops.py
+++ b/python/pyspark/pandas/data_type_ops/binary_ops.py
@@ -17,10 +17,19 @@
 
 from typing import TYPE_CHECKING, Union
 
-from pyspark.sql import functions as F
-from pyspark.sql.types import BinaryType
+from pandas.api.types import CategoricalDtype
+
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql import functions as F
+from pyspark.sql.types import BinaryType, BooleanType, StringType
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -53,3 +62,17 @@ class BinaryOps(DataTypeOps):
         raise TypeError(
             "Concatenation can not be applied to %s and the given type." % self.pretty_name
         )
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/data_type_ops/boolean_ops.py b/python/pyspark/pandas/data_type_ops/boolean_ops.py
index 0015218576..47604559de 100644
--- a/python/pyspark/pandas/data_type_ops/boolean_ops.py
+++ b/python/pyspark/pandas/data_type_ops/boolean_ops.py
@@ -19,6 +19,7 @@ import numbers
 from typing import TYPE_CHECKING, Union
 
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 from pyspark import sql as spark
 from pyspark.pandas.base import column_op, IndexOpsMixin
@@ -26,11 +27,15 @@ from pyspark.pandas.data_type_ops.base import (
     is_valid_operand_for_numeric_arithmetic,
     DataTypeOps,
     transform_boolean_operand_to_numeric,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
 )
-from pyspark.pandas.typedef import extension_dtypes
+from pyspark.pandas.internal import InternalField
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
 from pyspark.pandas.typedef.typehints import as_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.types import BooleanType
+from pyspark.sql.types import BooleanType, StringType
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -245,6 +250,32 @@ class BooleanOps(DataTypeOps):
 
             return column_op(or_func)(left, right)
 
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            if isinstance(dtype, extension_dtypes):
+                scol = F.when(
+                    index_ops.spark.column.isNotNull(),
+                    F.when(index_ops.spark.column, "True").otherwise("False"),
+                )
+            else:
+                null_str = str(None)
+                casted = F.when(index_ops.spark.column, "True").otherwise("False")
+                scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+
 
 class BooleanExtensionOps(BooleanOps):
     """
diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py
index 89cf9037e7..a20f61da3f 100644
--- a/python/pyspark/pandas/data_type_ops/categorical_ops.py
+++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -15,9 +15,20 @@
 # limitations under the License.
 #
 
-import pandas as pd
+from itertools import chain
+from typing import cast, TYPE_CHECKING, Union
 
+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
+import pyspark.pandas as ps
 from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql import functions as F
+
+if TYPE_CHECKING:
+    from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
 
 
 class CategoricalOps(DataTypeOps):
@@ -38,3 +49,24 @@ class CategoricalOps(DataTypeOps):
     def prepare(self, col: pd.Series) -> pd.Series:
         """Prepare column when from_pandas."""
         return col.cat.codes
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype) and dtype.categories is None:
+            return cast(Union[ps.Index, ps.Series], index_ops).copy()
+
+        categories = index_ops.dtype.categories
+        if len(categories) == 0:
+            scol = F.lit(None)
+        else:
+            kvs = chain(
+                *[(F.lit(code), F.lit(category)) for code, category in enumerate(categories)]
+            )
+            map_scol = F.create_map(*kvs)
+            scol = map_scol.getItem(index_ops.spark.column)
+        return index_ops._with_new_scol(
+            scol.alias(index_ops._internal.data_spark_column_names[0])
+        ).astype(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/complex_ops.py b/python/pyspark/pandas/data_type_ops/complex_ops.py
index edf709f476..c29063d375 100644
--- a/python/pyspark/pandas/data_type_ops/complex_ops.py
+++ b/python/pyspark/pandas/data_type_ops/complex_ops.py
@@ -17,10 +17,19 @@
 
 from typing import TYPE_CHECKING, Union
 
+from pandas.api.types import CategoricalDtype
+
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
 from pyspark.sql import functions as F
-from pyspark.sql.types import ArrayType, NumericType
+from pyspark.sql.types import ArrayType, BooleanType, NumericType, StringType
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -56,6 +65,20 @@ class ArrayOps(DataTypeOps):
 
         return column_op(F.concat)(left, right)
 
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+
 
 class MapOps(DataTypeOps):
     """
diff --git a/python/pyspark/pandas/data_type_ops/date_ops.py b/python/pyspark/pandas/data_type_ops/date_ops.py
index 3d71bbf885..c57be9112c 100644
--- a/python/pyspark/pandas/data_type_ops/date_ops.py
+++ b/python/pyspark/pandas/data_type_ops/date_ops.py
@@ -19,11 +19,21 @@ import datetime
 import warnings
 from typing import TYPE_CHECKING, Union
 
+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
 from pyspark.sql import functions as F
-from pyspark.sql.types import DateType
+from pyspark.sql.types import BooleanType, DateType, StringType
 
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -69,3 +79,17 @@ class DateOps(DataTypeOps):
             return -column_op(F.datediff)(left, F.lit(right)).astype("long")
         else:
             raise TypeError("date subtraction can only be applied to date series.")
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(pd.NaT))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py
index 7676eeec35..4cd8e37dc8 100644
--- a/python/pyspark/pandas/data_type_ops/datetime_ops.py
+++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -19,12 +19,21 @@ import datetime
 import warnings
 from typing import TYPE_CHECKING, Union
 
+import pandas as pd
+from pandas.api.types import CategoricalDtype
+
 from pyspark.sql import functions as F
-from pyspark.sql.types import TimestampType
+from pyspark.sql.types import BooleanType, StringType, TimestampType
 
 from pyspark.pandas.base import IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
-from pyspark.pandas.typedef import as_spark_type
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+)
+from pyspark.pandas.internal import InternalField
+from pyspark.pandas.typedef import as_spark_type, Dtype, extension_dtypes, pandas_on_spark_type
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -78,3 +87,29 @@ class DatetimeOps(DataTypeOps):
     def prepare(self, col):
         """Prepare column when from_pandas."""
         return col
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            if isinstance(dtype, extension_dtypes):
+                # seems like a pandas' bug?
+                scol = F.when(index_ops.spark.column.isNull(), str(pd.NaT)).otherwise(
+                    index_ops.spark.column.cast(spark_type)
+                )
+            else:
+                null_str = str(pd.NaT)
+                casted = index_ops.spark.column.cast(spark_type)
+                scol = F.when(index_ops.spark.column.isNull(), null_str).otherwise(casted)
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/data_type_ops/null_ops.py b/python/pyspark/pandas/data_type_ops/null_ops.py
index 49a18f7c5d..33d7755cbf 100644
--- a/python/pyspark/pandas/data_type_ops/null_ops.py
+++ b/python/pyspark/pandas/data_type_ops/null_ops.py
@@ -15,7 +15,23 @@
 # limitations under the License.
 #
 
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from typing import TYPE_CHECKING, Union
+
+from pandas.api.types import CategoricalDtype
+
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.typedef import Dtype, pandas_on_spark_type
+from pyspark.sql.types import BooleanType, StringType
+
+if TYPE_CHECKING:
+    from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
+    from pyspark.pandas.series import Series  # noqa: F401 (SPARK-34943)
 
 
 class NullOps(DataTypeOps):
@@ -26,3 +42,17 @@ class NullOps(DataTypeOps):
     @property
     def pretty_name(self) -> str:
         return "nulls"
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index be58e79aeb..bd3e148267 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -19,22 +19,30 @@ import numbers
 from typing import TYPE_CHECKING, Union
 
 import numpy as np
-
-from pyspark.sql import functions as F
-from pyspark.sql.types import (
-    StringType,
-    TimestampType,
-)
+from pandas.api.types import CategoricalDtype
 
 from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op
 from pyspark.pandas.data_type_ops.base import (
     is_valid_operand_for_numeric_arithmetic,
     DataTypeOps,
     transform_boolean_operand_to_numeric,
+    _as_bool_type,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
 )
+from pyspark.pandas.internal import InternalField
 from pyspark.pandas.spark import functions as SF
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
+from pyspark.sql import functions as F
 from pyspark.sql.column import Column
-
+from pyspark.sql.types import (
+    BooleanType,
+    DoubleType,
+    FloatType,
+    StringType,
+    TimestampType,
+)
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -248,6 +256,20 @@ class IntegralOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)
 
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            return _as_bool_type(index_ops, dtype)
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(np.nan))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
+
 
 class FractionalOps(NumericOps):
     """
@@ -344,3 +366,32 @@ class FractionalOps(NumericOps):
         right = transform_boolean_operand_to_numeric(right, left.spark.data_type)
         return numpy_column_op(rfloordiv)(left, right)
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+        elif isinstance(spark_type, BooleanType):
+            if isinstance(dtype, extension_dtypes):
+                scol = index_ops.spark.column.cast(spark_type)
+            else:
+                if isinstance(index_ops.spark.data_type, (FloatType, DoubleType)):
+                    scol = F.when(
+                        index_ops.spark.column.isNull() | F.isnan(index_ops.spark.column),
+                        F.lit(True),
+                    ).otherwise(index_ops.spark.column.cast(spark_type))
+                else:  # DecimalType
+                    scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+                        index_ops.spark.column.cast(spark_type)
+                    )
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype, null_str=str(np.nan))
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/data_type_ops/string_ops.py b/python/pyspark/pandas/data_type_ops/string_ops.py
index 9695affa63..cf31842df5 100644
--- a/python/pyspark/pandas/data_type_ops/string_ops.py
+++ b/python/pyspark/pandas/data_type_ops/string_ops.py
@@ -23,8 +23,16 @@ from pyspark.sql import functions as F
 from pyspark.sql.types import IntegralType, StringType
 
 from pyspark.pandas.base import column_op, IndexOpsMixin
-from pyspark.pandas.data_type_ops.base import DataTypeOps
+from pyspark.pandas.data_type_ops.base import (
+    DataTypeOps,
+    _as_categorical_type,
+    _as_other_type,
+    _as_string_type,
+)
+from pyspark.pandas.internal import InternalField
 from pyspark.pandas.spark import functions as SF
+from pyspark.pandas.typedef import Dtype, extension_dtypes, pandas_on_spark_type
+from pyspark.sql.types import BooleanType
 
 if TYPE_CHECKING:
     from pyspark.pandas.indexes import Index  # noqa: F401 (SPARK-34943)
@@ -102,3 +110,27 @@ class StringOps(DataTypeOps):
 
     def rmod(self, left, right):
         raise TypeError("modulo can not be applied on string series or literals.")
+
+    def astype(
+        self, index_ops: Union["Index", "Series"], dtype: Union[str, type, Dtype]
+    ) -> Union["Index", "Series"]:
+        dtype, spark_type = pandas_on_spark_type(dtype)
+
+        if isinstance(dtype, CategoricalDtype):
+            return _as_categorical_type(index_ops, dtype, spark_type)
+
+        if isinstance(spark_type, BooleanType):
+            if isinstance(dtype, extension_dtypes):
+                scol = index_ops.spark.column.cast(spark_type)
+            else:
+                scol = F.when(index_ops.spark.column.isNull(), F.lit(False)).otherwise(
+                    F.length(index_ops.spark.column) > 0
+                )
+            return index_ops._with_new_scol(
+                scol.alias(index_ops._internal.data_spark_column_names[0]),
+                field=InternalField(dtype=dtype),
+            )
+        elif isinstance(spark_type, StringType):
+            return _as_string_type(index_ops, dtype)
+        else:
+            return _as_other_type(index_ops, dtype, spark_type)
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
index 12e3eb24de..7ef060bcd0 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
@@ -16,6 +16,7 @@
 #
 
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -147,6 +148,14 @@ class BinaryOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pd.Series(["1", "2", "3"]), psser.astype(str))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[b"2", b"3", b"1"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
index 255e742f9a..00f3eafb50 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
@@ -21,6 +21,7 @@ from distutils.version import LooseVersion
 
 import pandas as pd
 import numpy as np
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -287,6 +288,21 @@ class BooleanOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(True | pser, True | psser)
         self.assert_eq(False | pser, False | psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[False, True])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 @unittest.skipIf(not extension_dtypes_available, "pandas extension dtypes are not available")
 class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
@@ -578,6 +594,14 @@ class BooleanExtensionOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(["True", "False", "None"], self.psser.astype(str).tolist())
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[False, True])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.data_type_ops.test_boolean_ops import *  # noqa: F401
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
index a5ed1bb3c6..3a16174058 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py
@@ -15,7 +15,11 @@
 # limitations under the License.
 #
 
+from distutils.version import LooseVersion
+
 import pandas as pd
+import numpy as np
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -140,6 +144,25 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        data = [1, 2, 3]
+        pser = pd.Series(data, dtype="category")
+        psser = ps.from_pandas(pser)
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[3, 1, 2])
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+        else:
+            self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
index b2902cb79f..3be9aa2569 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_complex_ops.py
@@ -65,9 +65,13 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
     def pssers(self):
         return self.numeric_array_pssers + list(self.non_numeric_array_pssers.values())
 
+    @property
+    def pser(self):
+        return pd.Series([[1, 2, 3]])
+
     @property
     def psser(self):
-        return ps.Series([[1, 2, 3]])
+        return ps.from_pandas(self.pser)
 
     def test_add(self):
         for pser, psser in zip(self.psers, self.pssers):
@@ -213,6 +217,9 @@ class ComplexOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        self.assert_eq(self.pser.astype(str), self.psser.astype(str))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
index 3d94253e2f..18333f85fd 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_date_ops.py
@@ -18,6 +18,7 @@
 import datetime
 
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 from pyspark.sql.types import DateType
 
@@ -172,6 +173,14 @@ class DateOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pd.Series([None, None, None]), psser.astype(bool))
+        cat_type = CategoricalDtype(categories=["a", "b", "c"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
index e50e0175b6..436469a442 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_datetime_ops.py
@@ -19,6 +19,7 @@ import datetime
 
 import numpy as np
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -172,6 +173,14 @@ class DatetimeOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=["a", "b", "c"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
index 0308938464..0e66b10685 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_null_ops.py
@@ -16,6 +16,7 @@
 #
 
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 import pyspark.pandas as ps
 from pyspark.pandas.config import option_context
@@ -122,6 +123,15 @@ class NullOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = self.pser
+        psser = self.psser
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=[1, 2, 3])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
index f898bf7b3d..59fa0fc118 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py
@@ -20,6 +20,7 @@ from distutils.version import LooseVersion
 
 import pandas as pd
 import numpy as np
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -293,6 +294,20 @@ class NumOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        for pser, psser in self.numeric_pser_psser_pairs:
+            self.assert_eq(pser.astype(int), psser.astype(int))
+            self.assert_eq(pser.astype(float), psser.astype(float))
+            self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+            self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+            self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+            self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+            self.assert_eq(pser.astype(str), psser.astype(str))
+            self.assert_eq(pser.astype(bool), psser.astype(bool))
+            self.assert_eq(pser.astype("category"), psser.astype("category"))
+            cat_type = CategoricalDtype(categories=[2, 1, 3])
+            self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
index 62f9406d1b..12dbc4a7e8 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_string_ops.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import pandas as pd
+from pandas.api.types import CategoricalDtype
 
 from pyspark import pandas as ps
 from pyspark.pandas.config import option_context
@@ -154,10 +155,25 @@ class StringOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        pser = pd.Series(["1", "2", "3"])
+        psser = ps.from_pandas(pser)
+        self.assert_eq(pser.astype(int), psser.astype(int))
+        self.assert_eq(pser.astype(float), psser.astype(float))
+        self.assert_eq(pser.astype(np.float32), psser.astype(np.float32))
+        self.assert_eq(pser.astype(np.int32), psser.astype(np.int32))
+        self.assert_eq(pser.astype(np.int16), psser.astype(np.int16))
+        self.assert_eq(pser.astype(np.int8), psser.astype(np.int8))
+        self.assert_eq(pser.astype(str), psser.astype(str))
+        self.assert_eq(pser.astype(bool), psser.astype(bool))
+        self.assert_eq(pser.astype("category"), psser.astype("category"))
+        cat_type = CategoricalDtype(categories=["3", "1", "2"])
+        self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
+
 
 if __name__ == "__main__":
     import unittest
-    from pyspark.pandas.tests.data_type_ops.test_num_ops import *  # noqa: F401
+    from pyspark.pandas.tests.data_type_ops.test_string_ops import *  # noqa: F401
 
     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
index fdad7a431b..8cdbc97c94 100644
--- a/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
+++ b/python/pyspark/pandas/tests/data_type_ops/test_udt_ops.py
@@ -125,6 +125,9 @@ class UDTOpsTest(PandasOnSparkTestCase, TestCasesUtils):
         self.assert_eq(pser, psser.to_pandas())
         self.assert_eq(ps.from_pandas(pser), psser)
 
+    def test_astype(self):
+        self.assertRaises(TypeError, lambda: self.psser.astype(str))
+
 
 if __name__ == "__main__":
     import unittest