From f2f09e4cdba815a3eeffdf0cdbf94dc9aa3c5634 Mon Sep 17 00:00:00 2001 From: itholic Date: Mon, 9 Aug 2021 11:10:01 +0900 Subject: [PATCH] [SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3 This PR proposes fixing the `Index.union` to follow the behavior of pandas 1.3. Before: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64') ``` After: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64') ``` This bug is fixed in https://github.com/pandas-dev/pandas/issues/36289. We should follow the behavior of pandas as much as possible. Yes, the result for some cases have duplicates values will change. Unit test. Closes #33634 from itholic/SPARK-36369. Authored-by: itholic Signed-off-by: Hyukjin Kwon (cherry picked from commit a9f371c2470ce28251012dea7428ff9be80bf3e5) Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/indexes/base.py | 4 +- .../pyspark/pandas/tests/indexes/test_base.py | 130 +++++++++--------- 2 files changed, 68 insertions(+), 66 deletions(-) diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index a43a5d1628..9d0d75aacd 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -2336,9 +2336,7 @@ class Index(IndexOpsMixin): sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns) - sdf = sdf_self.union(sdf_other.subtract(sdf_self)) - if isinstance(self, MultiIndex): - sdf = sdf.drop_duplicates() + sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other)) if sort: sdf = sdf.sort(*self._internal.index_spark_column_names) diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 39e22bd116..605e3f8fc0 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -1487,21 +1487,20 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils): almost=True, ) - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: - self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1)) - self.assert_eq( - psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]), - pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]), - almost=True, - ) - self.assert_eq( - psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])), - pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])), - almost=True, - ) + # Manually create the expected result here since there is a bug in Index.union + # dropping duplicated values in pandas < 1.3. + expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6]) + self.assert_eq(psidx2.union(psidx1), expected) + self.assert_eq( + psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]), + expected, + almost=True, + ) + self.assert_eq( + psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])), + expected, + almost=True, + ) # MultiIndex pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]) @@ -1513,80 +1512,85 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils): psmidx3 = ps.from_pandas(pmidx3) psmidx4 = ps.from_pandas(pmidx4) - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: - self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2)) - self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1)) - self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4)) - self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3)) - self.assert_eq( - psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]), - pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]), - ) - self.assert_eq( - psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]), - pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]), - ) - self.assert_eq( - psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]), - pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]), - ) - self.assert_eq( - psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]), - pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]), - ) + # Manually create the expected result here since there is a bug in MultiIndex.union + # dropping duplicated values in pandas < 1.3. + expected = pd.MultiIndex.from_tuples( + [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")] + ) + self.assert_eq(psmidx1.union(psmidx2), expected) + self.assert_eq(psmidx2.union(psmidx1), expected) + self.assert_eq( + psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]), + expected, + ) + self.assert_eq( + psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]), + expected, + ) + + expected = pd.MultiIndex.from_tuples( + [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)] + ) + self.assert_eq(psmidx3.union(psmidx4), expected) + self.assert_eq(psmidx4.union(psmidx3), expected) + self.assert_eq( + psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]), + expected, + ) + self.assert_eq( + psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]), + expected, + ) - if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass # Testing if the result is correct after sort=False. # The `sort` argument is added in pandas 0.24. - elif LooseVersion(pd.__version__) >= LooseVersion("0.24"): + if LooseVersion(pd.__version__) >= LooseVersion("0.24"): + # Manually create the expected result here since there is a bug in MultiIndex.union + # dropping duplicated values in pandas < 1.3. + expected = pd.MultiIndex.from_tuples( + [("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")] + ) self.assert_eq( psmidx1.union(psmidx2, sort=False).sort_values(), - pmidx1.union(pmidx2, sort=False).sort_values(), + expected, ) self.assert_eq( psmidx2.union(psmidx1, sort=False).sort_values(), - pmidx2.union(pmidx1, sort=False).sort_values(), - ) - self.assert_eq( - psmidx3.union(psmidx4, sort=False).sort_values(), - pmidx3.union(pmidx4, sort=False).sort_values(), - ) - self.assert_eq( - psmidx4.union(psmidx3, sort=False).sort_values(), - pmidx4.union(pmidx3, sort=False).sort_values(), + expected, ) self.assert_eq( psmidx1.union( [("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False ).sort_values(), - pmidx1.union( - [("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False - ).sort_values(), + expected, ) self.assert_eq( psmidx2.union( [("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False ).sort_values(), - pmidx2.union( - [("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False - ).sort_values(), + expected, + ) + + expected = pd.MultiIndex.from_tuples( + [(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)] + ) + self.assert_eq( + psmidx3.union(psmidx4, sort=False).sort_values(), + expected, + ) + self.assert_eq( + psmidx4.union(psmidx3, sort=False).sort_values(), + expected, ) self.assert_eq( psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(), - pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(), + expected, ) self.assert_eq( psmidx4.union( [(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False ).sort_values(), - pmidx4.union( - [(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False - ).sort_values(), + expected, ) self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))