From 33e463ccf99d09ad8a743d32104f590e204da93d Mon Sep 17 00:00:00 2001
From: Xinrong Meng
Date: Tue, 21 Sep 2021 10:20:15 -0700
Subject: [PATCH] [SPARK-36769][PYTHON] Improve `filter` of single-indexed DataFrame

### What changes were proposed in this pull request?
Improve `filter` of a single-indexed DataFrame by replacing a long Project with a Filter or a Join.

### Why are the changes needed?
When the given `items` have too many elements, a long Project is introduced. For better performance, we replace that with `Column.isin` when `items` is short, or with a broadcast semi join otherwise, depending on the length of `items` relative to `compute.isin_limit`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests.

Closes #33998 from xinrong-databricks/impr_filter.

Authored-by: Xinrong Meng
Signed-off-by: Takuya UESHIN
---
 python/pyspark/pandas/frame.py                | 28 +++++++++++++------
 python/pyspark/pandas/tests/test_dataframe.py |  7 +++++
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 09efef2750..cba1db1a65 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -9995,13 +9995,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 raise ValueError("items should be a list-like object.")
             if axis == 0:
                 if len(index_scols) == 1:
-                    col = None
-                    for item in items:
-                        if col is None:
-                            col = index_scols[0] == SF.lit(item)
-                        else:
-                            col = col | (index_scols[0] == SF.lit(item))
-                elif len(index_scols) > 1:
+                    if len(items) <= ps.get_option("compute.isin_limit"):
+                        col = index_scols[0].isin([SF.lit(item) for item in items])
+                        return DataFrame(self._internal.with_filter(col))
+                    else:
+                        item_sdf_col = verify_temp_column_name(
+                            self._internal.spark_frame, "__item__"
+                        )
+                        item_sdf = default_session().createDataFrame(
+                            pd.DataFrame({item_sdf_col: items})
+                        )
+                        joined_sdf = self._internal.spark_frame.join(
+                            other=F.broadcast(item_sdf),
+                            on=(index_scols[0] == scol_for(item_sdf, item_sdf_col)),
+                            how="semi",
+                        )
+
+                        return DataFrame(self._internal.with_new_sdf(joined_sdf))
+
+                else:
                     # for multi-index
                     col = None
                     for item in items:
@@ -10019,7 +10031,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                             col = midx_col
                         else:
                             col = col | midx_col
-                return DataFrame(self._internal.with_filter(col))
+                    return DataFrame(self._internal.with_filter(col))
             else:
                 return self[items]
         elif like is not None:
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 20aecc22a4..3cfbc0334d 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -4313,6 +4313,13 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             psdf.filter(items=["ab", "aa"], axis=0).sort_index(),
             pdf.filter(items=["ab", "aa"], axis=0).sort_index(),
         )
+
+        with option_context("compute.isin_limit", 0):
+            self.assert_eq(
+                psdf.filter(items=["ab", "aa"], axis=0).sort_index(),
+                pdf.filter(items=["ab", "aa"], axis=0).sort_index(),
+            )
+
         self.assert_eq(
             psdf.filter(items=["ba", "db"], axis=1).sort_index(),
             pdf.filter(items=["ba", "db"], axis=1).sort_index(),
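
---
For context, a minimal standalone sketch of the two plan shapes this patch chooses between, written against the public PySpark API. This is illustrative, not code from the PR: the session, data, column names, and the hard-coded threshold standing in for `ps.get_option("compute.isin_limit")` are all assumptions for the example.

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.range(100).withColumnRenamed("id", "index")
    items = [3, 14, 15, 92]

    if len(items) <= 64:  # stands in for ps.get_option("compute.isin_limit")
        # Short list: a single isin predicate compiles to one Filter,
        # keeping the query plan small.
        filtered = sdf.where(F.col("index").isin(items))
    else:
        # Long list: ship the items as a tiny DataFrame and do a broadcast
        # semi join, avoiding one enormous boolean expression in the plan.
        item_sdf = spark.createDataFrame([(i,) for i in items], ["__item__"])
        filtered = sdf.join(
            F.broadcast(item_sdf),
            on=sdf["index"] == item_sdf["__item__"],
            how="semi",
        )

The semi join returns only the rows of the left side whose index matches some item, without adding the broadcast side's columns, which is why the patch can wrap the joined frame directly with `with_new_sdf`.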