From ca8a3f2e23ace915696593540b27abff2b24e631 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 14 Jul 2021 14:01:10 -0700 Subject: [PATCH] [SPARK-36125][PYTHON] Implement non-equality comparison operators between two Categoricals ### What changes were proposed in this pull request? Implement non-equality comparison operators between two Categoricals. Non-goal: supporting Scalar input will be a follow-up task. ### Why are the changes needed? pandas supports non-equality comparisons between two Categoricals. We should follow that. ### Does this PR introduce _any_ user-facing change? Yes. The `<`, `<=`, `>`, `>=` operators between two Categoricals no longer raise `NotImplementedError`. An example is shown below: From: ```py >>> import pyspark.pandas as ps >>> from pandas.api.types import CategoricalDtype >>> psser = ps.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) >>> other_psser = ps.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) >>> with ps.option_context("compute.ops_on_diff_frames", True): ... psser <= other_psser ... Traceback (most recent call last): ... NotImplementedError: <= can not be applied to categoricals. ``` To: ```py >>> import pyspark.pandas as ps >>> from pandas.api.types import CategoricalDtype >>> psser = ps.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) >>> other_psser = ps.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) >>> with ps.option_context("compute.ops_on_diff_frames", True): ... psser <= other_psser ... 0 False 1 True 2 True dtype: bool ``` ### How was this patch tested? Unit tests. Closes #33331 from xinrong-databricks/categorical_compare. 
Authored-by: Xinrong Meng Signed-off-by: Takuya UESHIN (cherry picked from commit 0cb120f390cac96c09ae99c8fbaec2ac06cd2848) Signed-off-by: Takuya UESHIN --- .../pandas/data_type_ops/categorical_ops.py | 29 +++- .../data_type_ops/test_categorical_ops.py | 149 +++++++++++++++++- 2 files changed, 167 insertions(+), 11 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/categorical_ops.py b/python/pyspark/pandas/data_type_ops/categorical_ops.py index 9238e6bdbb..fb5666d850 100644 --- a/python/pyspark/pandas/data_type_ops/categorical_ops.py +++ b/python/pyspark/pandas/data_type_ops/categorical_ops.py @@ -22,10 +22,12 @@ import pandas as pd from pandas.api.types import CategoricalDtype from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex +from pyspark.pandas.base import column_op, IndexOpsMixin from pyspark.pandas.data_type_ops.base import DataTypeOps from pyspark.pandas.spark import functions as SF from pyspark.pandas.typedef import pandas_on_spark_type from pyspark.sql import functions as F +from pyspark.sql.column import Column class CategoricalOps(DataTypeOps): @@ -64,15 +66,28 @@ class CategoricalOps(DataTypeOps): scol = map_scol.getItem(index_ops.spark.column) return index_ops._with_new_scol(scol).astype(dtype) - # TODO(SPARK-35997): Implement comparison operators below def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - raise NotImplementedError("< can not be applied to %s." % self.pretty_name) + _non_equality_comparison_input_check(left, right) + return column_op(Column.__lt__)(left, right) def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - raise NotImplementedError("<= can not be applied to %s." % self.pretty_name) - - def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - raise NotImplementedError("> can not be applied to %s." 
% self.pretty_name) + _non_equality_comparison_input_check(left, right) + return column_op(Column.__le__)(left, right) def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: - raise NotImplementedError(">= can not be applied to %s." % self.pretty_name) + _non_equality_comparison_input_check(left, right) + return column_op(Column.__gt__)(left, right) + + def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex: + _non_equality_comparison_input_check(left, right) + return column_op(Column.__ge__)(left, right) + + +def _non_equality_comparison_input_check(left: IndexOpsLike, right: Any) -> None: + if not left.dtype.ordered: + raise TypeError("Unordered Categoricals can only compare equality or not.") + if isinstance(right, IndexOpsMixin) and isinstance(right.dtype, CategoricalDtype): + if hash(left.dtype) != hash(right.dtype): + raise TypeError("Categoricals can only be compared if 'categories' are the same.") + else: + raise TypeError("Cannot compare a Categorical with the given type.") diff --git a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py index c0fb2408a9..840722c43a 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_categorical_ops.py @@ -44,6 +44,26 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils): def other_psser(self): return ps.from_pandas(self.other_pser) + @property + def ordered_pser(self): + return pd.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) + + @property + def ordered_psser(self): + return ps.from_pandas(self.ordered_pser) + + @property + def other_ordered_pser(self): + return pd.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True)) + + @property + def other_ordered_psser(self): + return ps.from_pandas(self.other_ordered_pser) + + @property + def unordered_psser(self): + return ps.Series([1, 2, 
3]).astype(CategoricalDtype([3, 2, 1])) + def test_add(self): self.assertRaises(TypeError, lambda: self.psser + "x") self.assertRaises(TypeError, lambda: self.psser + 1) @@ -198,16 +218,137 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils): self.assert_eq(self.pser != self.pser, (self.psser != self.psser).sort_index()) def test_lt(self): - self.assertRaises(NotImplementedError, lambda: self.psser < self.other_psser) + ordered_pser = self.ordered_pser + ordered_psser = self.ordered_psser + self.assert_eq(ordered_pser < ordered_pser, ordered_psser < ordered_psser) + with option_context("compute.ops_on_diff_frames", True): + self.assert_eq( + ordered_pser < self.other_ordered_pser, ordered_psser < self.other_ordered_psser + ) + self.assertRaisesRegex( + TypeError, + "Unordered Categoricals can only compare equality or not", + lambda: self.unordered_psser < ordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Categoricals can only be compared if 'categories' are the same", + lambda: ordered_psser < self.unordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser < ps.Series([1, 2, 3]), + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser < [1, 2, 3], + ) + self.assertRaisesRegex( + TypeError, "Cannot compare a Categorical with the given type", lambda: ordered_psser < 1 + ) def test_le(self): - self.assertRaises(NotImplementedError, lambda: self.psser <= self.other_psser) + ordered_pser = self.ordered_pser + ordered_psser = self.ordered_psser + self.assert_eq(ordered_pser <= ordered_pser, ordered_psser <= ordered_psser) + + with option_context("compute.ops_on_diff_frames", True): + self.assert_eq( + ordered_pser <= self.other_ordered_pser, ordered_psser <= self.other_ordered_psser + ) + self.assertRaisesRegex( + TypeError, + "Unordered Categoricals can only compare equality or not", + lambda: 
self.unordered_psser <= ordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Categoricals can only be compared if 'categories' are the same", + lambda: ordered_psser <= self.unordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser <= ps.Series([1, 2, 3]), + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser <= [1, 2, 3], + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser <= 1, + ) def test_gt(self): - self.assertRaises(NotImplementedError, lambda: self.psser > self.other_psser) + ordered_pser = self.ordered_pser + ordered_psser = self.ordered_psser + self.assert_eq(ordered_pser > ordered_pser, ordered_psser > ordered_psser) + with option_context("compute.ops_on_diff_frames", True): + self.assert_eq( + ordered_pser > self.other_ordered_pser, ordered_psser > self.other_ordered_psser + ) + self.assertRaisesRegex( + TypeError, + "Unordered Categoricals can only compare equality or not", + lambda: self.unordered_psser > ordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Categoricals can only be compared if 'categories' are the same", + lambda: ordered_psser > self.unordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser > ps.Series([1, 2, 3]), + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser > [1, 2, 3], + ) + self.assertRaisesRegex( + TypeError, "Cannot compare a Categorical with the given type", lambda: ordered_psser > 1 + ) def test_ge(self): - self.assertRaises(NotImplementedError, lambda: self.psser >= self.other_psser) + ordered_pser = self.ordered_pser + ordered_psser = self.ordered_psser + self.assert_eq(ordered_pser >= ordered_pser, ordered_psser >= ordered_psser) + with 
option_context("compute.ops_on_diff_frames", True): + self.assert_eq( + ordered_pser >= self.other_ordered_pser, ordered_psser >= self.other_ordered_psser + ) + self.assertRaisesRegex( + TypeError, + "Unordered Categoricals can only compare equality or not", + lambda: self.unordered_psser >= ordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Categoricals can only be compared if 'categories' are the same", + lambda: ordered_psser >= self.unordered_psser, + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser >= ps.Series([1, 2, 3]), + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser >= [1, 2, 3], + ) + self.assertRaisesRegex( + TypeError, + "Cannot compare a Categorical with the given type", + lambda: ordered_psser >= 1, + ) if __name__ == "__main__":