[SPARK-36771][PYTHON] Fix pop of Categorical Series

### What changes were proposed in this pull request?
Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.

### Why are the changes needed?
As https://github.com/databricks/koalas/issues/2198, pandas API on Spark behaves differently from pandas on `pop` of Categorical Series.

### Does this PR introduce _any_ user-facing change?
Yes, results of `pop` of Categorical Series change.

#### From
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
0
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
0
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

#### To
```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
'a'
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
'a'
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']

```

### How was this patch tested?
Unit tests.

Closes #34052 from xinrong-databricks/cat_pop.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
Xinrong Meng 2021-09-21 14:11:21 -07:00 committed by Takuya UESHIN
parent b7d99e3eea
commit 079a9c5292
2 changed files with 31 additions and 2 deletions

View file

@ -47,7 +47,7 @@ import numpy as np
import pandas as pd
from pandas.core.accessor import CachedAccessor
from pandas.io.formats.printing import pprint_thing
from pandas.api.types import is_list_like, is_hashable
from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
from pandas.tseries.frequencies import DateOffset
from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
from pyspark.sql.types import (
@ -4098,7 +4098,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
pdf = sdf.limit(2).toPandas()
length = len(pdf)
if length == 1:
return pdf[internal.data_spark_column_names[0]].iloc[0]
val = pdf[internal.data_spark_column_names[0]].iloc[0]
if isinstance(self.dtype, CategoricalDtype):
return self.dtype.categories[val]
else:
return val
item_string = name_like_string(item)
sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string)))

View file

@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
with self.assertRaisesRegex(KeyError, msg):
psser.pop(("lama", "speed", "x"))
pser = pd.Series(["a", "b", "c", "a"], dtype="category")
psser = ps.from_pandas(pser)
if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
self.assert_eq(psser.pop(0), pser.pop(0))
self.assert_eq(psser, pser)
self.assert_eq(psser.pop(3), pser.pop(3))
self.assert_eq(psser, pser)
else:
# Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly.
self.assert_eq(psser.pop(0), "a")
self.assert_eq(
psser,
pd.Series(
pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3]
),
)
self.assert_eq(psser.pop(3), "a")
self.assert_eq(
psser,
pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]),
)
def test_replace(self):
pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
psser = ps.Series(pser)