[SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3

This PR proposes fixing the `Index.union` to follow the behavior of pandas 1.3.

Before:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64')
```

After:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64')
```

This bug is fixed in https://github.com/pandas-dev/pandas/issues/36289.

We should follow the behavior of pandas as much as possible.

Yes, the result for some cases have duplicates values will change.

Unit test.

Closes #33634 from itholic/SPARK-36369.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit a9f371c247)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
itholic 2021-08-09 11:10:01 +09:00 committed by Hyukjin Kwon
parent cb075b5301
commit f2f09e4cdb
2 changed files with 68 additions and 66 deletions

View file

@ -2336,9 +2336,7 @@ class Index(IndexOpsMixin):
sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
sdf = sdf_self.union(sdf_other.subtract(sdf_self))
if isinstance(self, MultiIndex):
sdf = sdf.drop_duplicates()
sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
if sort:
sdf = sdf.sort(*self._internal.index_spark_column_names)

View file

@ -1487,21 +1487,20 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
almost=True,
)
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
pass
else:
self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
self.assert_eq(
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
almost=True,
)
self.assert_eq(
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
almost=True,
)
# Manually create the expected result here since there is a bug in Index.union
# dropping duplicated values in pandas < 1.3.
expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
self.assert_eq(psidx2.union(psidx1), expected)
self.assert_eq(
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
expected,
almost=True,
)
self.assert_eq(
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
expected,
almost=True,
)
# MultiIndex
pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
@ -1513,80 +1512,85 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
psmidx3 = ps.from_pandas(pmidx3)
psmidx4 = ps.from_pandas(pmidx4)
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
pass
else:
self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
self.assert_eq(
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
)
self.assert_eq(
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
)
self.assert_eq(
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
)
self.assert_eq(
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
)
# Manually create the expected result here since there is a bug in MultiIndex.union
# dropping duplicated values in pandas < 1.3.
expected = pd.MultiIndex.from_tuples(
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
)
self.assert_eq(psmidx1.union(psmidx2), expected)
self.assert_eq(psmidx2.union(psmidx1), expected)
self.assert_eq(
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
expected,
)
self.assert_eq(
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
expected,
)
expected = pd.MultiIndex.from_tuples(
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
)
self.assert_eq(psmidx3.union(psmidx4), expected)
self.assert_eq(psmidx4.union(psmidx3), expected)
self.assert_eq(
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
expected,
)
self.assert_eq(
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
expected,
)
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
pass
# Testing if the result is correct after sort=False.
# The `sort` argument is added in pandas 0.24.
elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
# Manually create the expected result here since there is a bug in MultiIndex.union
# dropping duplicated values in pandas < 1.3.
expected = pd.MultiIndex.from_tuples(
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
)
self.assert_eq(
psmidx1.union(psmidx2, sort=False).sort_values(),
pmidx1.union(pmidx2, sort=False).sort_values(),
expected,
)
self.assert_eq(
psmidx2.union(psmidx1, sort=False).sort_values(),
pmidx2.union(pmidx1, sort=False).sort_values(),
)
self.assert_eq(
psmidx3.union(psmidx4, sort=False).sort_values(),
pmidx3.union(pmidx4, sort=False).sort_values(),
)
self.assert_eq(
psmidx4.union(psmidx3, sort=False).sort_values(),
pmidx4.union(pmidx3, sort=False).sort_values(),
expected,
)
self.assert_eq(
psmidx1.union(
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
).sort_values(),
pmidx1.union(
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
).sort_values(),
expected,
)
self.assert_eq(
psmidx2.union(
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
).sort_values(),
pmidx2.union(
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
).sort_values(),
expected,
)
expected = pd.MultiIndex.from_tuples(
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
)
self.assert_eq(
psmidx3.union(psmidx4, sort=False).sort_values(),
expected,
)
self.assert_eq(
psmidx4.union(psmidx3, sort=False).sort_values(),
expected,
)
self.assert_eq(
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
expected,
)
self.assert_eq(
psmidx4.union(
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
).sort_values(),
pmidx4.union(
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
).sort_values(),
expected,
)
self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))