[SPARK-36125][PYTHON] Implement non-equality comparison operators between two Categoricals
### What changes were proposed in this pull request?
Implement non-equality comparison operators between two Categoricals.
Non-goal: supporting Scalar input will be a follow-up task.
### Why are the changes needed?
pandas supports non-equality comparisons between two Categoricals. We should follow that.
### Does this PR introduce _any_ user-facing change?
Yes. The `<`, `<=`, `>`, `>=` operators between two Categoricals no longer raise `NotImplementedError`. An example is shown below:
From:
```py
>>> import pyspark.pandas as ps
>>> from pandas.api.types import CategoricalDtype
>>> psser = ps.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
>>> other_psser = ps.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
>>> with ps.option_context("compute.ops_on_diff_frames", True):
...     psser <= other_psser
...
Traceback (most recent call last):
...
NotImplementedError: <= can not be applied to categoricals.
```
To:
```py
>>> import pyspark.pandas as ps
>>> from pandas.api.types import CategoricalDtype
>>> psser = ps.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
>>> other_psser = ps.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
>>> with ps.option_context("compute.ops_on_diff_frames", True):
...     psser <= other_psser
...
0    False
1     True
2     True
dtype: bool
```
### How was this patch tested?
Unit tests.
Closes #33331 from xinrong-databricks/categorical_compare.
Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
(cherry picked from commit 0cb120f390)
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
0da71548a5
commit
ca8a3f2e23
|
@ -22,10 +22,12 @@ import pandas as pd
|
|||
from pandas.api.types import CategoricalDtype
|
||||
|
||||
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
|
||||
from pyspark.pandas.base import column_op, IndexOpsMixin
|
||||
from pyspark.pandas.data_type_ops.base import DataTypeOps
|
||||
from pyspark.pandas.spark import functions as SF
|
||||
from pyspark.pandas.typedef import pandas_on_spark_type
|
||||
from pyspark.sql import functions as F
|
||||
from pyspark.sql.column import Column
|
||||
|
||||
|
||||
class CategoricalOps(DataTypeOps):
|
||||
|
@ -64,15 +66,28 @@ class CategoricalOps(DataTypeOps):
|
|||
scol = map_scol.getItem(index_ops.spark.column)
|
||||
return index_ops._with_new_scol(scol).astype(dtype)
|
||||
|
||||
# TODO(SPARK-35997): Implement comparison operators below
|
||||
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
|
||||
raise NotImplementedError("< can not be applied to %s." % self.pretty_name)
|
||||
_non_equality_comparison_input_check(left, right)
|
||||
return column_op(Column.__lt__)(left, right)
|
||||
|
||||
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
|
||||
raise NotImplementedError("<= can not be applied to %s." % self.pretty_name)
|
||||
|
||||
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
|
||||
raise NotImplementedError("> can not be applied to %s." % self.pretty_name)
|
||||
_non_equality_comparison_input_check(left, right)
|
||||
return column_op(Column.__le__)(left, right)
|
||||
|
||||
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
|
||||
raise NotImplementedError(">= can not be applied to %s." % self.pretty_name)
|
||||
_non_equality_comparison_input_check(left, right)
|
||||
return column_op(Column.__gt__)(left, right)
|
||||
|
||||
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
|
||||
_non_equality_comparison_input_check(left, right)
|
||||
return column_op(Column.__ge__)(left, right)
|
||||
|
||||
|
||||
def _non_equality_comparison_input_check(left: IndexOpsLike, right: Any) -> None:
    """Validate the operands of a non-equality comparison (<, <=, >, >=).

    Parameters
    ----------
    left : IndexOpsLike
        The left operand; its ``dtype`` is read for the ``ordered`` flag.
    right : Any
        The right operand. Only an ``IndexOpsMixin`` whose dtype is a
        ``CategoricalDtype`` is accepted.

    Raises
    ------
    TypeError
        If ``left`` is an unordered categorical, if ``right`` is not an
        ``IndexOpsMixin`` with a ``CategoricalDtype``, or if the two
        categorical dtypes differ.
    """
    if not left.dtype.ordered:
        raise TypeError("Unordered Categoricals can only compare equality or not.")

    if not (isinstance(right, IndexOpsMixin) and isinstance(right.dtype, CategoricalDtype)):
        raise TypeError("Cannot compare a Categorical with the given type.")

    # Compare the dtypes themselves rather than their hashes: equal hashes do
    # not guarantee equal objects, so a hash collision could let mismatched
    # categoricals through.
    if left.dtype != right.dtype:
        raise TypeError("Categoricals can only be compared if 'categories' are the same.")
|
||||
|
|
|
@ -44,6 +44,26 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
def other_psser(self):
|
||||
return ps.from_pandas(self.other_pser)
|
||||
|
||||
@property
|
||||
def ordered_pser(self):
|
||||
return pd.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
|
||||
|
||||
@property
|
||||
def ordered_psser(self):
|
||||
return ps.from_pandas(self.ordered_pser)
|
||||
|
||||
@property
|
||||
def other_ordered_pser(self):
|
||||
return pd.Series([2, 1, 3]).astype(CategoricalDtype([3, 2, 1], ordered=True))
|
||||
|
||||
@property
|
||||
def other_ordered_psser(self):
|
||||
return ps.from_pandas(self.other_ordered_pser)
|
||||
|
||||
@property
|
||||
def unordered_psser(self):
|
||||
return ps.Series([1, 2, 3]).astype(CategoricalDtype([3, 2, 1]))
|
||||
|
||||
def test_add(self):
|
||||
self.assertRaises(TypeError, lambda: self.psser + "x")
|
||||
self.assertRaises(TypeError, lambda: self.psser + 1)
|
||||
|
@ -198,16 +218,137 @@ class CategoricalOpsTest(PandasOnSparkTestCase, TestCasesUtils):
|
|||
self.assert_eq(self.pser != self.pser, (self.psser != self.psser).sort_index())
|
||||
|
||||
def test_lt(self):
|
||||
self.assertRaises(NotImplementedError, lambda: self.psser < self.other_psser)
|
||||
ordered_pser = self.ordered_pser
|
||||
ordered_psser = self.ordered_psser
|
||||
self.assert_eq(ordered_pser < ordered_pser, ordered_psser < ordered_psser)
|
||||
with option_context("compute.ops_on_diff_frames", True):
|
||||
self.assert_eq(
|
||||
ordered_pser < self.other_ordered_pser, ordered_psser < self.other_ordered_psser
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Unordered Categoricals can only compare equality or not",
|
||||
lambda: self.unordered_psser < ordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Categoricals can only be compared if 'categories' are the same",
|
||||
lambda: ordered_psser < self.unordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser < ps.Series([1, 2, 3]),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser < [1, 2, 3],
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError, "Cannot compare a Categorical with the given type", lambda: ordered_psser < 1
|
||||
)
|
||||
|
||||
def test_le(self):
|
||||
self.assertRaises(NotImplementedError, lambda: self.psser <= self.other_psser)
|
||||
ordered_pser = self.ordered_pser
|
||||
ordered_psser = self.ordered_psser
|
||||
self.assert_eq(ordered_pser <= ordered_pser, ordered_psser <= ordered_psser)
|
||||
|
||||
with option_context("compute.ops_on_diff_frames", True):
|
||||
self.assert_eq(
|
||||
ordered_pser <= self.other_ordered_pser, ordered_psser <= self.other_ordered_psser
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Unordered Categoricals can only compare equality or not",
|
||||
lambda: self.unordered_psser <= ordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Categoricals can only be compared if 'categories' are the same",
|
||||
lambda: ordered_psser <= self.unordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser <= ps.Series([1, 2, 3]),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser <= [1, 2, 3],
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser <= 1,
|
||||
)
|
||||
|
||||
def test_gt(self):
|
||||
self.assertRaises(NotImplementedError, lambda: self.psser > self.other_psser)
|
||||
ordered_pser = self.ordered_pser
|
||||
ordered_psser = self.ordered_psser
|
||||
self.assert_eq(ordered_pser > ordered_pser, ordered_psser > ordered_psser)
|
||||
with option_context("compute.ops_on_diff_frames", True):
|
||||
self.assert_eq(
|
||||
ordered_pser > self.other_ordered_pser, ordered_psser > self.other_ordered_psser
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Unordered Categoricals can only compare equality or not",
|
||||
lambda: self.unordered_psser > ordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Categoricals can only be compared if 'categories' are the same",
|
||||
lambda: ordered_psser > self.unordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser > ps.Series([1, 2, 3]),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser > [1, 2, 3],
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError, "Cannot compare a Categorical with the given type", lambda: ordered_psser > 1
|
||||
)
|
||||
|
||||
def test_ge(self):
|
||||
self.assertRaises(NotImplementedError, lambda: self.psser >= self.other_psser)
|
||||
ordered_pser = self.ordered_pser
|
||||
ordered_psser = self.ordered_psser
|
||||
self.assert_eq(ordered_pser >= ordered_pser, ordered_psser >= ordered_psser)
|
||||
with option_context("compute.ops_on_diff_frames", True):
|
||||
self.assert_eq(
|
||||
ordered_pser >= self.other_ordered_pser, ordered_psser >= self.other_ordered_psser
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Unordered Categoricals can only compare equality or not",
|
||||
lambda: self.unordered_psser >= ordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Categoricals can only be compared if 'categories' are the same",
|
||||
lambda: ordered_psser >= self.unordered_psser,
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser >= ps.Series([1, 2, 3]),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser >= [1, 2, 3],
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"Cannot compare a Categorical with the given type",
|
||||
lambda: ordered_psser >= 1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue