diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst
index b0b4cdd221..60c4c4b339 100644
--- a/python/docs/source/reference/pyspark.pandas/indexing.rst
+++ b/python/docs/source/reference/pyspark.pandas/indexing.rst
@@ -176,6 +176,7 @@ Categorical components
    CategoricalIndex.categories
    CategoricalIndex.ordered
    CategoricalIndex.add_categories
+   CategoricalIndex.remove_categories
    CategoricalIndex.as_ordered
    CategoricalIndex.as_unordered
 
diff --git a/python/docs/source/reference/pyspark.pandas/series.rst b/python/docs/source/reference/pyspark.pandas/series.rst
index 6243a22454..fec89de65d 100644
--- a/python/docs/source/reference/pyspark.pandas/series.rst
+++ b/python/docs/source/reference/pyspark.pandas/series.rst
@@ -402,6 +402,7 @@ the ``Series.cat`` accessor.
    Series.cat.ordered
    Series.cat.codes
    Series.cat.add_categories
+   Series.cat.remove_categories
    Series.cat.as_ordered
    Series.cat.as_unordered
 
diff --git a/python/pyspark/pandas/categorical.py b/python/pyspark/pandas/categorical.py
index a83c3c741c..529041bc75 100644
--- a/python/pyspark/pandas/categorical.py
+++ b/python/pyspark/pandas/categorical.py
@@ -348,8 +348,96 @@ class CategoricalAccessor(object):
         """
         return self._set_ordered(ordered=False, inplace=inplace)
 
-    def remove_categories(self, removals: pd.Index, inplace: bool = False) -> "ps.Series":
-        raise NotImplementedError()
+    def remove_categories(
+        self, removals: Union[pd.Index, Any, List], inplace: bool = False
+    ) -> Optional["ps.Series"]:
+        """
+        Remove the specified categories.
+
+        `removals` must be included in the old categories. Values which were in
+        the removed categories will be set to NaN
+
+        Parameters
+        ----------
+        removals : category or list of categories
+            The categories which should be removed.
+        inplace : bool, default False
+            Whether or not to remove the categories inplace or return a copy of
+            this categorical with removed categories.
+
+        Returns
+        -------
+        Series or None
+            Categorical with removed categories or None if ``inplace=True``.
+
+        Raises
+        ------
+        ValueError
+            If the removals are not contained in the categories
+
+        Examples
+        --------
+        >>> s = ps.Series(list("abbccc"), dtype="category")
+        >>> s  # doctest: +SKIP
+        0    a
+        1    b
+        2    b
+        3    c
+        4    c
+        5    c
+        dtype: category
+        Categories (3, object): ['a', 'b', 'c']
+
+        >>> s.cat.remove_categories('b')  # doctest: +SKIP
+        0      a
+        1    NaN
+        2    NaN
+        3      c
+        4      c
+        5      c
+        dtype: category
+        Categories (2, object): ['a', 'c']
+        """
+        if is_list_like(removals):
+            categories = [cat for cat in removals if cat is not None]  # type: List
+        elif removals is None:
+            categories = []
+        else:
+            categories = [removals]
+
+        if any(cat not in self.categories for cat in categories):
+            raise ValueError(
+                "removals must all be in old categories: {{{cats}}}".format(
+                    cats=", ".join(
+                        set(str(cat) for cat in categories if cat not in self.categories)
+                    )
+                )
+            )
+
+        if len(categories) == 0:
+            if inplace:
+                return None
+            else:
+                psser = self._data
+                return psser._with_new_scol(
+                    psser.spark.column, field=psser._internal.data_fields[0]
+                )
+        else:
+            dtype = CategoricalDtype(
+                [cat for cat in self.categories if cat not in categories], ordered=self.ordered
+            )
+            psser = self._data.astype(dtype)
+
+            if inplace:
+                internal = self._data._psdf._internal.with_new_spark_column(
+                    self._data._column_label,
+                    psser.spark.column,
+                    field=psser._internal.data_fields[0],
+                )
+                self._data._psdf._update_internal_frame(internal)
+                return None
+            else:
+                return psser
 
     def remove_unused_categories(self) -> "ps.Series":
         raise NotImplementedError()
diff --git a/python/pyspark/pandas/indexes/category.py b/python/pyspark/pandas/indexes/category.py
index 308043e353..28b50271bb 100644
--- a/python/pyspark/pandas/indexes/category.py
+++ b/python/pyspark/pandas/indexes/category.py
@@ -312,6 +312,49 @@ class CategoricalIndex(Index):
 
         return CategoricalIndex(self.to_series().cat.as_unordered()).rename(self.name)
 
+    def remove_categories(
+        self, removals: Union[pd.Index, Any, List], inplace: bool = False
+    ) -> Optional["CategoricalIndex"]:
+        """
+        Remove the specified categories.
+
+        `removals` must be included in the old categories. Values which were in
+        the removed categories will be set to NaN
+
+        Parameters
+        ----------
+        removals : category or list of categories
+            The categories which should be removed.
+        inplace : bool, default False
+            Whether or not to remove the categories inplace or return a copy of
+            this categorical with removed categories.
+
+        Returns
+        -------
+        CategoricalIndex or None
+            Categorical with removed categories or None if ``inplace=True``.
+
+        Raises
+        ------
+        ValueError
+            If the removals are not contained in the categories
+
+        Examples
+        --------
+        >>> idx = ps.CategoricalIndex(list("abbccc"))
+        >>> idx  # doctest: +NORMALIZE_WHITESPACE
+        CategoricalIndex(['a', 'b', 'b', 'c', 'c', 'c'],
+                         categories=['a', 'b', 'c'], ordered=False, dtype='category')
+
+        >>> idx.remove_categories('b')  # doctest: +NORMALIZE_WHITESPACE
+        CategoricalIndex(['a', nan, nan, 'c', 'c', 'c'],
+                         categories=['a', 'c'], ordered=False, dtype='category')
+        """
+        if inplace:
+            raise ValueError("cannot use inplace with CategoricalIndex")
+
+        return CategoricalIndex(self.to_series().cat.remove_categories(removals)).rename(self.name)
+
     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeCategoricalIndex, item):
             property_or_func = getattr(MissingPandasLikeCategoricalIndex, item)
diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py
index 2a5a4c9087..e550801ce3 100644
--- a/python/pyspark/pandas/missing/indexes.py
+++ b/python/pyspark/pandas/missing/indexes.py
@@ -125,7 +125,6 @@ class MissingPandasLikeCategoricalIndex(MissingPandasLikeIndex):
     # Functions
     rename_categories = _unsupported_function("rename_categories", cls="CategoricalIndex")
     reorder_categories = _unsupported_function("reorder_categories", cls="CategoricalIndex")
-    remove_categories = _unsupported_function("remove_categories", cls="CategoricalIndex")
     remove_unused_categories = _unsupported_function(
         "remove_unused_categories", cls="CategoricalIndex"
     )
diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py
index 44e270348e..ebda1be7c6 100644
--- a/python/pyspark/pandas/tests/indexes/test_category.py
+++ b/python/pyspark/pandas/tests/indexes/test_category.py
@@ -106,6 +106,22 @@ class CategoricalIndexTest(PandasOnSparkTestCase, TestUtils):
         self.assertRaises(ValueError, lambda: psidx.add_categories(3))
         self.assertRaises(ValueError, lambda: psidx.add_categories([4, 4]))
 
+    def test_remove_categories(self):
+        pidx = pd.CategoricalIndex([1, 2, 3], categories=[3, 2, 1])
+        psidx = ps.from_pandas(pidx)
+
+        self.assert_eq(pidx.remove_categories(2), psidx.remove_categories(2))
+        self.assert_eq(pidx.remove_categories([1, 3]), psidx.remove_categories([1, 3]))
+        self.assert_eq(pidx.remove_categories([]), psidx.remove_categories([]))
+        self.assert_eq(pidx.remove_categories([2, 2]), psidx.remove_categories([2, 2]))
+        self.assert_eq(pidx.remove_categories([1, 2, 3]), psidx.remove_categories([1, 2, 3]))
+        self.assert_eq(pidx.remove_categories(None), psidx.remove_categories(None))
+        self.assert_eq(pidx.remove_categories([None]), psidx.remove_categories([None]))
+
+        self.assertRaises(ValueError, lambda: pidx.remove_categories(4, inplace=True))
+        self.assertRaises(ValueError, lambda: psidx.remove_categories(4))
+        self.assertRaises(ValueError, lambda: psidx.remove_categories([4, None]))
+
     def test_as_ordered_unordered(self):
         pidx = pd.CategoricalIndex(["x", "y", "z"], categories=["z", "y", "x"])
         psidx = ps.from_pandas(pidx)
diff --git a/python/pyspark/pandas/tests/test_categorical.py b/python/pyspark/pandas/tests/test_categorical.py
index 1af03d69df..cf36563e7e 100644
--- a/python/pyspark/pandas/tests/test_categorical.py
+++ b/python/pyspark/pandas/tests/test_categorical.py
@@ -97,6 +97,30 @@ class CategoricalTest(PandasOnSparkTestCase, TestUtils):
         self.assertRaises(ValueError, lambda: psser.cat.add_categories(4))
         self.assertRaises(ValueError, lambda: psser.cat.add_categories([5, 5]))
 
+    def test_remove_categories(self):
+        pdf, psdf = self.df_pair
+
+        pser = pdf.a
+        psser = psdf.a
+
+        self.assert_eq(pser.cat.remove_categories(2), psser.cat.remove_categories(2))
+        self.assert_eq(pser.cat.remove_categories([1, 3]), psser.cat.remove_categories([1, 3]))
+        self.assert_eq(pser.cat.remove_categories([]), psser.cat.remove_categories([]))
+        self.assert_eq(pser.cat.remove_categories([2, 2]), psser.cat.remove_categories([2, 2]))
+        self.assert_eq(
+            pser.cat.remove_categories([1, 2, 3]), psser.cat.remove_categories([1, 2, 3])
+        )
+        self.assert_eq(pser.cat.remove_categories(None), psser.cat.remove_categories(None))
+        self.assert_eq(pser.cat.remove_categories([None]), psser.cat.remove_categories([None]))
+
+        pser.cat.remove_categories(2, inplace=True)
+        psser.cat.remove_categories(2, inplace=True)
+        self.assert_eq(pser, psser)
+        self.assert_eq(pdf, psdf)
+
+        self.assertRaises(ValueError, lambda: psser.cat.remove_categories(4))
+        self.assertRaises(ValueError, lambda: psser.cat.remove_categories([4, None]))
+
     def test_as_ordered_unordered(self):
         pdf, psdf = self.df_pair
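For quick reference, a minimal usage sketch of the API this patch enables, mirroring the doctests above; it assumes pyspark.pandas is installed and relies on the Spark session that pandas-on-Spark starts automatically on first use.

    import pyspark.pandas as ps

    # Series accessor: entries whose category is removed become NaN,
    # and the categorical dtype shrinks to the remaining categories.
    s = ps.Series(list("abbccc"), dtype="category")
    print(s.cat.remove_categories("b").to_pandas())

    # Index variant: same semantics, except inplace=True raises ValueError.
    idx = ps.CategoricalIndex(list("abbccc"))
    print(idx.remove_categories("b"))

    # Removals that are not existing categories raise ValueError.
    try:
        s.cat.remove_categories("z")
    except ValueError as exc:
        print(exc)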