From 4543ac62bce8a71eed8b913228d06647a556e654 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 21 Sep 2021 19:16:27 -0700 Subject: [PATCH] [SPARK-36771][PYTHON][3.2] Fix `pop` of Categorical Series ### What changes were proposed in this pull request? Fix `pop` of Categorical Series to be consistent with the latest pandas (1.3.2) behavior. This is a backport of https://github.com/apache/spark/pull/34052. ### Why are the changes needed? As https://github.com/databricks/koalas/issues/2198, pandas API on Spark behaves differently from pandas on `pop` of Categorical Series. ### Does this PR introduce _any_ user-facing change? Yes, results of `pop` of Categorical Series change. #### From ```py >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category") >>> psser 0 a 1 b 2 c 3 a dtype: category Categories (3, object): ['a', 'b', 'c'] >>> psser.pop(0) 0 >>> psser 1 b 2 c 3 a dtype: category Categories (3, object): ['a', 'b', 'c'] >>> psser.pop(3) 0 >>> psser 1 b 2 c dtype: category Categories (3, object): ['a', 'b', 'c'] ``` #### To ```py >>> psser = ps.Series(["a", "b", "c", "a"], dtype="category") >>> psser 0 a 1 b 2 c 3 a dtype: category Categories (3, object): ['a', 'b', 'c'] >>> psser.pop(0) 'a' >>> psser 1 b 2 c 3 a dtype: category Categories (3, object): ['a', 'b', 'c'] >>> psser.pop(3) 'a' >>> psser 1 b 2 c dtype: category Categories (3, object): ['a', 'b', 'c'] ``` ### How was this patch tested? Unit tests. Closes #34063 from xinrong-databricks/backport_cat_pop. Authored-by: Xinrong Meng Signed-off-by: Takuya UESHIN --- python/pyspark/pandas/series.py | 8 +++++-- python/pyspark/pandas/tests/test_series.py | 25 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index 0eebcc9745..e96b3228f1 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -47,7 +47,7 @@ import numpy as np import pandas as pd from pandas.core.accessor import CachedAccessor from pandas.io.formats.printing import pprint_thing -from pandas.api.types import is_list_like, is_hashable +from pandas.api.types import is_list_like, is_hashable, CategoricalDtype from pandas.api.extensions import ExtensionDtype from pandas.tseries.frequencies import DateOffset from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame @@ -4060,7 +4060,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]): pdf = sdf.limit(2).toPandas() length = len(pdf) if length == 1: - return pdf[internal.data_spark_column_names[0]].iloc[0] + val = pdf[internal.data_spark_column_names[0]].iloc[0] + if isinstance(self.dtype, CategoricalDtype): + return self.dtype.categories[val] + else: + return val item_string = name_like_string(item) sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string))) diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index cbfc999515..bde0a34f38 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(KeyError, msg): psser.pop(("lama", "speed", "x")) + pser = pd.Series(["a", "b", "c", "a"], dtype="category") + psser = ps.from_pandas(pser) + + if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"): + self.assert_eq(psser.pop(0), pser.pop(0)) + self.assert_eq(psser, pser) + + self.assert_eq(psser.pop(3), pser.pop(3)) + self.assert_eq(psser, pser) + else: + # Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly. + self.assert_eq(psser.pop(0), "a") + self.assert_eq( + psser, + pd.Series( + pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3] + ), + ) + + self.assert_eq(psser.pop(3), "a") + self.assert_eq( + psser, + pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]), + ) + def test_replace(self): pser = pd.Series([10, 20, 15, 30, np.nan], name="x") psser = ps.Series(pser)