[SPARK-36771][PYTHON][3.2] Fix pop of Categorical Series
### What changes were proposed in this pull request?

Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior.

This is a backport of https://github.com/apache/spark/pull/34052.

### Why are the changes needed?

As reported in https://github.com/databricks/koalas/issues/2198, pandas API on Spark behaves differently from pandas on `pop` of Categorical Series.

### Does this PR introduce _any_ user-facing change?

Yes, results of `pop` of Categorical Series change.

#### From

```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
0
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
0
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

#### To

```py
>>> psser = ps.Series(["a", "b", "c", "a"], dtype="category")
>>> psser
0    a
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(0)
'a'
>>> psser
1    b
2    c
3    a
dtype: category
Categories (3, object): ['a', 'b', 'c']
>>> psser.pop(3)
'a'
>>> psser
1    b
2    c
dtype: category
Categories (3, object): ['a', 'b', 'c']
```

### How was this patch tested?

Unit tests.

Closes #34063 from xinrong-databricks/backport_cat_pop.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
affd7a4d47
commit
4543ac62bc
|
@ -47,7 +47,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
from pandas.core.accessor import CachedAccessor
|
||||
from pandas.io.formats.printing import pprint_thing
|
||||
from pandas.api.types import is_list_like, is_hashable
|
||||
from pandas.api.types import is_list_like, is_hashable, CategoricalDtype
|
||||
from pandas.api.extensions import ExtensionDtype
|
||||
from pandas.tseries.frequencies import DateOffset
|
||||
from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame
|
||||
|
@ -4060,7 +4060,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
|
|||
pdf = sdf.limit(2).toPandas()
|
||||
length = len(pdf)
|
||||
if length == 1:
|
||||
return pdf[internal.data_spark_column_names[0]].iloc[0]
|
||||
val = pdf[internal.data_spark_column_names[0]].iloc[0]
|
||||
if isinstance(self.dtype, CategoricalDtype):
|
||||
return self.dtype.categories[val]
|
||||
else:
|
||||
return val
|
||||
|
||||
item_string = name_like_string(item)
|
||||
sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string)))
|
||||
|
|
|
@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
|
|||
with self.assertRaisesRegex(KeyError, msg):
|
||||
psser.pop(("lama", "speed", "x"))
|
||||
|
||||
pser = pd.Series(["a", "b", "c", "a"], dtype="category")
|
||||
psser = ps.from_pandas(pser)
|
||||
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"):
|
||||
self.assert_eq(psser.pop(0), pser.pop(0))
|
||||
self.assert_eq(psser, pser)
|
||||
|
||||
self.assert_eq(psser.pop(3), pser.pop(3))
|
||||
self.assert_eq(psser, pser)
|
||||
else:
|
||||
# Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly.
|
||||
self.assert_eq(psser.pop(0), "a")
|
||||
self.assert_eq(
|
||||
psser,
|
||||
pd.Series(
|
||||
pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3]
|
||||
),
|
||||
)
|
||||
|
||||
self.assert_eq(psser.pop(3), "a")
|
||||
self.assert_eq(
|
||||
psser,
|
||||
pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]),
|
||||
)
|
||||
|
||||
def test_replace(self):
|
||||
pser = pd.Series([10, 20, 15, 30, np.nan], name="x")
|
||||
psser = ps.Series(pser)
|
||||
|
|
Loading…
Reference in a new issue