[SPARK-36771][PYTHON][3.2] Fix pop of Categorical Series

### What changes were proposed in this pull request?
Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.

This is a backport of https://github.com/apache/spark/pull/34052.

### Why are the changes needed?
As reported in https://github.com/databricks/koalas/issues/2198, `pop` on a Categorical Series in pandas API on Spark behaves differently from pandas.

### Does this PR introduce _any_ user-facing change?
Yes, `pop` on a Categorical Series now returns the category value (e.g. `'a'`) instead of its integer category code (e.g. `0`).

#### From
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
0
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
0
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

#### To
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
'a'
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
'a'
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

### How was this patch tested?
Unit tests.

Closes #34063 from xinrong-databricks/backport_cat_pop.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
Xinrong Meng 2021-09-21 19:16:27 -07:00 committed by Takuya UESHIN
parent affd7a4d47
commit 4543ac62bc
2 changed files with 31 additions and 2 deletions

View file

@ -47,7 +47,7 @@ import numpy as np
import pandas as pd import pandas as pd
from pandas.core.accessor import CachedAccessor from pandas.core.accessor import CachedAccessor
from pandas.io.formats.printing import pprint_thing from pandas.io.formats.printing import pprint_thing
from pandas.api.types import is_list_like, is_hashable from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
from pandas.api.extensions import ExtensionDtype from pandas.api.extensions import ExtensionDtype
from pandas.tseries.frequencies import DateOffset from pandas.tseries.frequencies import DateOffset
from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
@ -4060,7 +4060,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
pdf = sdf.limit(2).toPandas() pdf = sdf.limit(2).toPandas()
length = len(pdf) length = len(pdf)
if length == 1: if length == 1:
return pdf[internal.data_spark_column_names[0]].iloc[0] val = pdf[internal.data_spark_column_names[0]].iloc[0]
if isinstance(self.dtype, CategoricalDtype):
return self.dtype.categories[val]
else:
return val
item_string = name_like_string(item) item_string = name_like_string(item)
sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string))) sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string)))

View file

@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaisesRegex(KeyError, msg): with self.assertRaisesRegex(KeyError, msg):
psser.pop(("lama", "speed", "x")) psser.pop(("lama", "speed", "x"))
pser = pd.Series(["a", "b", "c", "a"], dtype="category")
psser = ps.from_pandas(pser)
if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
self.assert_eq(psser.pop(0), pser.pop(0))
self.assert_eq(psser, pser)
self.assert_eq(psser.pop(3), pser.pop(3))
self.assert_eq(psser, pser)
else:
# Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly.
self.assert_eq(psser.pop(0), "a")
self.assert_eq(
psser,
pd.Series(
pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3]
),
)
self.assert_eq(psser.pop(3), "a")
self.assert_eq(
psser,
pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]),
)
def test_replace(self): def test_replace(self):
pser = pd.Series([10, 20, 15, 30, np.nan], name="x") pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
psser = ps.Series(pser) psser = ps.Series(pser)