[SPARK-36722][PYTHON] Fix Series.update with another in same frame

### What changes were proposed in this pull request?
Fix Series.update with another in same frame

Also add a test for updating a Series with another Series from a different frame.

### Why are the changes needed?
Fix Series.update with another in same frame

Pandas behavior:
``` python
>>> pdf = pd.DataFrame(
...     {"a": [None, 2, 3, 4, 5, 6, 7, 8, None], "b": [None, 5, None, 3, 2, 1, None, 0, 0]},
... )
>>> pdf
     a    b
0  NaN  NaN
1  2.0  5.0
2  3.0  NaN
3  4.0  3.0
4  5.0  2.0
5  6.0  1.0
6  7.0  NaN
7  8.0  0.0
8  NaN  0.0
>>> pdf.a.update(pdf.b)
>>> pdf
     a    b
0  NaN  NaN
1  5.0  5.0
2  3.0  NaN
3  3.0  3.0
4  2.0  2.0
5  1.0  1.0
6  7.0  NaN
7  0.0  0.0
8  0.0  0.0
```

### Does this PR introduce _any_ user-facing change?
Before
```python
>>> psdf = ps.DataFrame(
...     {"a": [None, 2, 3, 4, 5, 6, 7, 8, None], "b": [None, 5, None, 3, 2, 1, None, 0, 0]},
... )

>>> psdf.a.update(psdf.b)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/dgd/spark/python/pyspark/pandas/series.py", line 4551, in update
    combined = combine_frames(self._psdf, other._psdf, how="leftouter")
  File "/Users/dgd/spark/python/pyspark/pandas/utils.py", line 141, in combine_frames
    assert not same_anchor(
AssertionError: We don't need to combine. `this` and `that` are same.
>>>
```

After
```python
>>> psdf = ps.DataFrame(
...     {"a": [None, 2, 3, 4, 5, 6, 7, 8, None], "b": [None, 5, None, 3, 2, 1, None, 0, 0]},
... )

>>> psdf.a.update(psdf.b)
>>> psdf
     a    b
0  NaN  NaN
1  5.0  5.0
2  3.0  NaN
3  3.0  3.0
4  2.0  2.0
5  1.0  1.0
6  7.0  NaN
7  0.0  0.0
8  0.0  0.0
>>>
```

### How was this patch tested?
unit tests

Closes #33968 from dgd-contributor/SPARK-36722_fix_update_same_anchor.

Authored-by: dgd-contributor <dgd_contributor@viettel.com.vn>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
dgd-contributor 2021-09-15 11:08:01 -07:00 committed by Takuya UESHIN
parent b665782f0d
commit c15072cc73
3 changed files with 64 additions and 12 deletions

View file

@@ -4536,22 +4536,33 @@ class Series(Frame, IndexOpsMixin, Generic[T]):

         if not isinstance(other, Series):
             raise TypeError("'other' must be a Series")
-        combined = combine_frames(self._psdf, other._psdf, how="leftouter")
-
-        this_scol = combined["this"]._internal.spark_column_for(self._column_label)
-        that_scol = combined["that"]._internal.spark_column_for(other._column_label)
-
-        scol = (
-            F.when(that_scol.isNotNull(), that_scol)
-            .otherwise(this_scol)
-            .alias(self._psdf._internal.spark_column_name_for(self._column_label))
-        )
-
-        internal = combined["this"]._internal.with_new_spark_column(
-            self._column_label, scol  # TODO: dtype?
-        )
-
-        self._psdf._update_internal_frame(internal.resolved_copy, requires_same_anchor=False)
+        if same_anchor(self, other):
+            scol = (
+                F.when(other.spark.column.isNotNull(), other.spark.column)
+                .otherwise(self.spark.column)
+                .alias(self._psdf._internal.spark_column_name_for(self._column_label))
+            )
+            internal = self._psdf._internal.with_new_spark_column(
+                self._column_label, scol  # TODO: dtype?
+            )
+            self._psdf._update_internal_frame(internal)
+        else:
+            combined = combine_frames(self._psdf, other._psdf, how="leftouter")
+
+            this_scol = combined["this"]._internal.spark_column_for(self._column_label)
+            that_scol = combined["that"]._internal.spark_column_for(other._column_label)
+
+            scol = (
+                F.when(that_scol.isNotNull(), that_scol)
+                .otherwise(this_scol)
+                .alias(self._psdf._internal.spark_column_name_for(self._column_label))
+            )
+
+            internal = combined["this"]._internal.with_new_spark_column(
+                self._column_label, scol  # TODO: dtype?
+            )
+
+            self._psdf._update_internal_frame(internal.resolved_copy, requires_same_anchor=False)

     def where(self, cond: "Series", other: Any = np.nan) -> "Series":
         """

View file

@@ -1358,6 +1358,15 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils):
         self.assert_eq(psser.sort_index(), pser.sort_index())
         self.assert_eq(psdf.sort_index(), pdf.sort_index())

+        pser1 = pd.Series([None, 2, 3, 4, 5, 6, 7, 8, None])
+        pser2 = pd.Series([None, 5, None, 3, 2, 1, None, 0, 0])
+        psser1 = ps.from_pandas(pser1)
+        psser2 = ps.from_pandas(pser2)
+
+        pser1.update(pser2)
+        psser1.update(psser2)
+        self.assert_eq(psser1, pser1)
+
     def test_where(self):
         pdf1 = pd.DataFrame({"A": [0, 1, 2, 3, 4], "B": [100, 200, 300, 400, 500]})
         pdf2 = pd.DataFrame({"A": [0, -1, -2, -3, -4], "B": [-100, -200, -300, -400, -500]})

View file

@@ -1711,6 +1711,38 @@ class SeriesTest(PandasOnSparkTestCase, SQLTestUtils):
         with self.assertRaisesRegex(TypeError, msg):
             psser.update(10)

+        def _get_data():
+            pdf = pd.DataFrame(
+                {
+                    "a": [None, 2, 3, 4, 5, 6, 7, 8, None],
+                    "b": [None, 5, None, 3, 2, 1, None, 0, 0],
+                    "c": [1, 5, 1, 3, 2, 1, 1, 0, 0],
+                },
+            )
+            psdf = ps.from_pandas(pdf)
+            return pdf, psdf
+
+        pdf, psdf = _get_data()
+        psdf.a.update(psdf.a)
+        pdf.a.update(pdf.a)
+        self.assert_eq(psdf, pdf)
+
+        pdf, psdf = _get_data()
+        psdf.a.update(psdf.b)
+        pdf.a.update(pdf.b)
+        self.assert_eq(psdf, pdf)
+
+        pdf, psdf = _get_data()
+        pser = pdf.a
+        psser = psdf.a
+        pser.update(pdf.b)
+        psser.update(psdf.b)
+        self.assert_eq(psser, pser)
+        self.assert_eq(psdf, pdf)
+
     def test_where(self):
         pser1 = pd.Series([0, 1, 2, 3, 4])
         psser1 = ps.from_pandas(pser1)