[SPARK-36368][PYTHON] Fix CategoricalOps.astype to follow pandas 1.3
### What changes were proposed in this pull request? This PR proposes to fix the behavior of `astype` for `CategoricalDtype` to follow pandas 1.3. **Before:** ```python >>> pcat 0 a 1 b 2 c dtype: category Categories (3, object): ['a', 'b', 'c'] >>> pcat.astype(CategoricalDtype(["b", "c", "a"])) 0 a 1 b 2 c dtype: category Categories (3, object): ['b', 'c', 'a'] ``` **After:** ```python >>> pcat 0 a 1 b 2 c dtype: category Categories (3, object): ['a', 'b', 'c'] >>> pcat.astype(CategoricalDtype(["b", "c", "a"])) 0 a 1 b 2 c dtype: category Categories (3, object): ['a', 'b', 'c'] # CategoricalDtype is not updated if dtype is the same ``` `CategoricalDtype` is treated as the same `dtype` if the unique values are the same. ```python >>> pcat1 = pser.astype(CategoricalDtype(["b", "c", "a"])) >>> pcat2 = pser.astype(CategoricalDtype(["a", "b", "c"])) >>> pcat1.dtype == pcat2.dtype True ``` ### Why are the changes needed? We should follow the latest pandas as much as possible. ### Does this PR introduce _any_ user-facing change? Yes, the behavior changes as shown in the examples in the PR description. ### How was this patch tested? Unit tests. Closes #33757 from itholic/SPARK-36368. Authored-by: itholic <haejoon.lee@databricks.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
c91ae544fd
commit
f2e593bcf1
|
@ -22,6 +22,7 @@ from pandas.api.types import CategoricalDtype, is_dict_like, is_list_like
|
|||
|
||||
from pyspark.pandas.internal import InternalField
|
||||
from pyspark.pandas.spark import functions as SF
|
||||
from pyspark.pandas.data_type_ops.categorical_ops import _to_cat
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.types import StructField
|
||||
|
||||
|
@ -735,7 +736,7 @@ class CategoricalAccessor(object):
|
|||
return self._data.copy()
|
||||
else:
|
||||
dtype = CategoricalDtype(categories=new_categories, ordered=ordered)
|
||||
psser = self._data.astype(dtype)
|
||||
psser = _to_cat(self._data).astype(dtype)
|
||||
|
||||
if inplace:
|
||||
internal = self._data._psdf._internal.with_new_spark_column(
|
||||
|
|
|
@ -57,7 +57,9 @@ class CategoricalOps(DataTypeOps):
|
|||
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
|
||||
dtype, _ = pandas_on_spark_type(dtype)
|
||||
|
||||
if isinstance(dtype, CategoricalDtype) and cast(CategoricalDtype, dtype).categories is None:
|
||||
if isinstance(dtype, CategoricalDtype) and (
|
||||
(dtype.categories is None) or (index_ops.dtype == dtype)
|
||||
):
|
||||
return index_ops.copy()
|
||||
|
||||
return _to_cat(index_ops).astype(dtype)
|
||||
|
|
|
@ -192,13 +192,11 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(pser.astype("category"), psser.astype("category"))
|
||||
|
||||
cat_type = CategoricalDtype(categories=[3, 1, 2])
|
||||
# CategoricalDtype is not updated if the dtype is same from pandas 1.3.
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
|
||||
self.assert_eq(pser.astype(cat_type), psser.astype(cat_type))
|
||||
else:
|
||||
self.assert_eq(pd.Series(data).astype(cat_type), psser.astype(cat_type))
|
||||
self.assert_eq(psser.astype(cat_type), pser)
|
||||
|
||||
def test_neg(self):
|
||||
self.assertRaises(TypeError, lambda: -self.psser)
|
||||
|
|
|
@ -172,25 +172,23 @@ class CategoricalIndexTest(PandasOnSparkTestCase, TestUtils):
|
|||
)
|
||||
|
||||
pcidx = pidx.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
kcidx = psidx.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
pscidx = psidx.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
|
||||
self.assert_eq(kcidx.astype("category"), pcidx.astype("category"))
|
||||
self.assert_eq(pscidx.astype("category"), pcidx.astype("category"))
|
||||
|
||||
# CategoricalDtype is not updated if the dtype is same from pandas 1.3.
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
|
||||
self.assert_eq(
|
||||
kcidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pscidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pcidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
)
|
||||
else:
|
||||
self.assert_eq(
|
||||
kcidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pscidx.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pcidx,
|
||||
)
|
||||
|
||||
self.assert_eq(kcidx.astype(str), pcidx.astype(str))
|
||||
self.assert_eq(pscidx.astype(str), pcidx.astype(str))
|
||||
|
||||
def test_factorize(self):
|
||||
pidx = pd.CategoricalIndex([1, 2, 3, None])
|
||||
|
|
|
@ -239,25 +239,23 @@ class CategoricalTest(PandasOnSparkTestCase, TestUtils):
|
|||
)
|
||||
|
||||
pcser = pser.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
kcser = psser.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
pscser = psser.astype(CategoricalDtype(["c", "a", "b"]))
|
||||
|
||||
self.assert_eq(kcser.astype("category"), pcser.astype("category"))
|
||||
self.assert_eq(pscser.astype("category"), pcser.astype("category"))
|
||||
|
||||
# CategoricalDtype is not updated if the dtype is same from pandas 1.3.
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
elif LooseVersion(pd.__version__) >= LooseVersion("1.2"):
|
||||
self.assert_eq(
|
||||
kcser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pscser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pcser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
)
|
||||
else:
|
||||
self.assert_eq(
|
||||
kcser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pscser.astype(CategoricalDtype(["b", "c", "a"])),
|
||||
pcser,
|
||||
)
|
||||
|
||||
self.assert_eq(kcser.astype(str), pcser.astype(str))
|
||||
self.assert_eq(pscser.astype(str), pcser.astype(str))
|
||||
|
||||
def test_factorize(self):
|
||||
pser = pd.Series(["a", "b", "c", None], dtype=CategoricalDtype(["c", "a", "d", "b"]))
|
||||
|
|
Loading…
Reference in a new issue