diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index 3e0797595b..e7a07e763c 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -33,6 +33,7 @@ import numpy as np from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import Label, Name, Scalar from pyspark.pandas.internal import ( + DEFAULT_SERIES_NAME, InternalField, InternalFrame, NATURAL_ORDER_COLUMN_NAME, @@ -435,11 +436,12 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta): if self._is_series: if isinstance(key, Series) and not same_anchor(key, self._psdf_or_psser): - psdf = self._psdf_or_psser.to_frame() + name = self._psdf_or_psser.name or DEFAULT_SERIES_NAME + psdf = self._psdf_or_psser.to_frame(name) temp_col = verify_temp_column_name(psdf, "__temp_col__") psdf[temp_col] = key - return type(self)(psdf[self._psdf_or_psser.name])[psdf[temp_col]] + return type(self)(psdf[name].rename(self._psdf_or_psser.name))[psdf[temp_col]] cond, limit, remaining_index = self._select_rows(key) if cond is None and limit is None: diff --git a/python/pyspark/pandas/tests/test_indexing.py b/python/pyspark/pandas/tests/test_indexing.py index b74cf90d07..2b00b3f952 100644 --- a/python/pyspark/pandas/tests/test_indexing.py +++ b/python/pyspark/pandas/tests/test_indexing.py @@ -417,6 +417,15 @@ class IndexingTest(PandasOnSparkTestCase): self.assertRaises(KeyError, lambda: psdf.loc[0:30]) self.assertRaises(KeyError, lambda: psdf.loc[10:100]) + def test_loc_getitem_boolean_series(self): + pdf = pd.DataFrame( + {"A": [0, 1, 2, 3, 4], "B": [100, 200, 300, 400, 500]}, index=[20, 10, 30, 0, 50] + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.A.loc[pdf.B > 200], psdf.A.loc[psdf.B > 200]) + self.assert_eq(pdf.B.loc[pdf.B > 200], psdf.B.loc[psdf.B > 200]) + self.assert_eq(pdf.loc[pdf.B > 200], psdf.loc[psdf.B > 200]) + def test_loc_non_informative_index(self): pdf = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) psdf = ps.from_pandas(pdf) diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py index 6a4855c752..1cc0ff51b8 100644 --- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py +++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py @@ -503,6 +503,12 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase, SQLTestUtils): (pdf1.A + 1).loc[pdf2.A > -3].sort_index(), (psdf1.A + 1).loc[psdf2.A > -3].sort_index() ) + pser = pd.Series([0, 1, 2, 3, 4], index=[20, 10, 30, 0, 50]) + psser = ps.from_pandas(pser) + self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index()) + pser.name = psser.name = "B" + self.assert_eq(pser.loc[pdf2.A > -3].sort_index(), psser.loc[psdf2.A > -3].sort_index()) + def test_bitwise(self): pser1 = pd.Series([True, False, True, False, np.nan, np.nan, True, False, np.nan]) pser2 = pd.Series([True, False, False, True, True, False, np.nan, np.nan, np.nan])