diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index d72c08d716..da0d2fbcda 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -47,7 +47,7 @@ import numpy as np import pandas as pd from pandas.core.accessor import CachedAccessor from pandas.io.formats.printing import pprint_thing -from pandas.api.types import is_list_like, is_hashable +from pandas.api.types import is_list_like, is_hashable, CategoricalDtype from pandas.tseries.frequencies import DateOffset from pyspark.sql import functions as F, Column, DataFrame as SparkDataFrame from pyspark.sql.types import ( @@ -4098,7 +4098,11 @@ class Series(Frame, IndexOpsMixin, Generic[T]): pdf = sdf.limit(2).toPandas() length = len(pdf) if length == 1: - return pdf[internal.data_spark_column_names[0]].iloc[0] + val = pdf[internal.data_spark_column_names[0]].iloc[0] + if isinstance(self.dtype, CategoricalDtype): + return self.dtype.categories[val] + else: + return val item_string = name_like_string(item) sdf = sdf.withColumn(SPARK_DEFAULT_INDEX_NAME, SF.lit(str(item_string))) diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index 09e5d304dc..b7bb121725 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -1669,6 +1669,31 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(KeyError, msg): psser.pop(("lama", "speed", "x")) + pser = pd.Series(["a", "b", "c", "a"], dtype="category") + psser = ps.from_pandas(pser) + + if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"): + self.assert_eq(psser.pop(0), pser.pop(0)) + self.assert_eq(psser, pser) + + self.assert_eq(psser.pop(3), pser.pop(3)) + self.assert_eq(psser, pser) + else: + # Before pandas 1.3.0, `pop` modifies the dtype of categorical series wrongly. + self.assert_eq(psser.pop(0), "a") + self.assert_eq( + psser, + pd.Series( + pd.Categorical(["b", "c", "a"], categories=["a", "b", "c"]), index=[1, 2, 3] + ), + ) + + self.assert_eq(psser.pop(3), "a") + self.assert_eq( + psser, + pd.Series(pd.Categorical(["b", "c"], categories=["a", "b", "c"]), index=[1, 2]), + ) + def test_replace(self): pser = pd.Series([10, 20, 15, 30, np.nan], name="x") psser = ps.Series(pser)