[SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
This PR proposes fixing the `Index.union` to follow the behavior of pandas 1.3.
Before:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64')
```
After:
```python
>>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2])
>>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2])
>>> ps_idx1.union(ps_idx2)
Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64')
```
This bug is fixed in https://github.com/pandas-dev/pandas/issues/36289.
We should follow the behavior of pandas as much as possible.
Yes, the result for some cases have duplicates values will change.
Unit test.
Closes #33634 from itholic/SPARK-36369.
Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
(cherry picked from commit a9f371c247
)
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
This commit is contained in:
parent
cb075b5301
commit
f2f09e4cdb
|
@ -2336,9 +2336,7 @@ class Index(IndexOpsMixin):
|
|||
|
||||
sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
|
||||
sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
|
||||
sdf = sdf_self.union(sdf_other.subtract(sdf_self))
|
||||
if isinstance(self, MultiIndex):
|
||||
sdf = sdf.drop_duplicates()
|
||||
sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
|
||||
if sort:
|
||||
sdf = sdf.sort(*self._internal.index_spark_column_names)
|
||||
|
||||
|
|
|
@ -1487,21 +1487,20 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
|
|||
almost=True,
|
||||
)
|
||||
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
else:
|
||||
self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
|
||||
self.assert_eq(
|
||||
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
|
||||
pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
|
||||
almost=True,
|
||||
)
|
||||
self.assert_eq(
|
||||
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
|
||||
pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
|
||||
almost=True,
|
||||
)
|
||||
# Manually create the expected result here since there is a bug in Index.union
|
||||
# dropping duplicated values in pandas < 1.3.
|
||||
expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
|
||||
self.assert_eq(psidx2.union(psidx1), expected)
|
||||
self.assert_eq(
|
||||
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
|
||||
expected,
|
||||
almost=True,
|
||||
)
|
||||
self.assert_eq(
|
||||
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
|
||||
expected,
|
||||
almost=True,
|
||||
)
|
||||
|
||||
# MultiIndex
|
||||
pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
|
||||
|
@ -1513,80 +1512,85 @@ class IndexesTest(PandasOnSparkTestCase, TestUtils):
|
|||
psmidx3 = ps.from_pandas(pmidx3)
|
||||
psmidx4 = ps.from_pandas(pmidx4)
|
||||
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
else:
|
||||
self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
|
||||
self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
|
||||
self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
|
||||
self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
|
||||
self.assert_eq(
|
||||
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
|
||||
pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
|
||||
pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
|
||||
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
|
||||
pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
|
||||
)
|
||||
# Manually create the expected result here since there is a bug in MultiIndex.union
|
||||
# dropping duplicated values in pandas < 1.3.
|
||||
expected = pd.MultiIndex.from_tuples(
|
||||
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
|
||||
)
|
||||
self.assert_eq(psmidx1.union(psmidx2), expected)
|
||||
self.assert_eq(psmidx2.union(psmidx1), expected)
|
||||
self.assert_eq(
|
||||
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
|
||||
expected,
|
||||
)
|
||||
|
||||
expected = pd.MultiIndex.from_tuples(
|
||||
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
|
||||
)
|
||||
self.assert_eq(psmidx3.union(psmidx4), expected)
|
||||
self.assert_eq(psmidx4.union(psmidx3), expected)
|
||||
self.assert_eq(
|
||||
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
|
||||
expected,
|
||||
)
|
||||
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
|
||||
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
|
||||
pass
|
||||
# Testing if the result is correct after sort=False.
|
||||
# The `sort` argument is added in pandas 0.24.
|
||||
elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
|
||||
if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
|
||||
# Manually create the expected result here since there is a bug in MultiIndex.union
|
||||
# dropping duplicated values in pandas < 1.3.
|
||||
expected = pd.MultiIndex.from_tuples(
|
||||
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx1.union(psmidx2, sort=False).sort_values(),
|
||||
pmidx1.union(pmidx2, sort=False).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx2.union(psmidx1, sort=False).sort_values(),
|
||||
pmidx2.union(pmidx1, sort=False).sort_values(),
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx3.union(psmidx4, sort=False).sort_values(),
|
||||
pmidx3.union(pmidx4, sort=False).sort_values(),
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx4.union(psmidx3, sort=False).sort_values(),
|
||||
pmidx4.union(pmidx3, sort=False).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx1.union(
|
||||
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
|
||||
).sort_values(),
|
||||
pmidx1.union(
|
||||
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
|
||||
).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx2.union(
|
||||
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
|
||||
).sort_values(),
|
||||
pmidx2.union(
|
||||
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
|
||||
).sort_values(),
|
||||
expected,
|
||||
)
|
||||
|
||||
expected = pd.MultiIndex.from_tuples(
|
||||
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx3.union(psmidx4, sort=False).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx4.union(psmidx3, sort=False).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
|
||||
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
|
||||
expected,
|
||||
)
|
||||
self.assert_eq(
|
||||
psmidx4.union(
|
||||
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
|
||||
).sort_values(),
|
||||
pmidx4.union(
|
||||
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
|
||||
).sort_values(),
|
||||
expected,
|
||||
)
|
||||
|
||||
self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))
|
||||
|
|
Loading…
Reference in a new issue