diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 09efef2750..cba1db1a65 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -9995,13 +9995,25 @@ defaultdict(, {'col..., 'col...})] raise ValueError("items should be a list-like object.") if axis == 0: if len(index_scols) == 1: - col = None - for item in items: - if col is None: - col = index_scols[0] == SF.lit(item) - else: - col = col | (index_scols[0] == SF.lit(item)) - elif len(index_scols) > 1: + if len(items) <= ps.get_option("compute.isin_limit"): + col = index_scols[0].isin([SF.lit(item) for item in items]) + return DataFrame(self._internal.with_filter(col)) + else: + item_sdf_col = verify_temp_column_name( + self._internal.spark_frame, "__item__" + ) + item_sdf = default_session().createDataFrame( + pd.DataFrame({item_sdf_col: items}) + ) + joined_sdf = self._internal.spark_frame.join( + other=F.broadcast(item_sdf), + on=(index_scols[0] == scol_for(item_sdf, item_sdf_col)), + how="semi", + ) + + return DataFrame(self._internal.with_new_sdf(joined_sdf)) + + else: # for multi-index col = None for item in items: @@ -10019,7 +10031,7 @@ defaultdict(, {'col..., 'col...})] col = midx_col else: col = col | midx_col - return DataFrame(self._internal.with_filter(col)) + return DataFrame(self._internal.with_filter(col)) else: return self[items] elif like is not None: diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 20aecc22a4..3cfbc0334d 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -4313,6 +4313,13 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils): psdf.filter(items=["ab", "aa"], axis=0).sort_index(), pdf.filter(items=["ab", "aa"], axis=0).sort_index(), ) + + with option_context("compute.isin_limit", 0): + self.assert_eq( + psdf.filter(items=["ab", "aa"], axis=0).sort_index(), + pdf.filter(items=["ab", "aa"], axis=0).sort_index(), + ) + self.assert_eq( psdf.filter(items=["ba", "db"], axis=1).sort_index(), pdf.filter(items=["ba", "db"], axis=1).sort_index(),