[SPARK-36769][PYTHON] Improve filter of single-indexed DataFrame

### What changes were proposed in this pull request?
Improve `filter` of single-indexed DataFrame by replacing a long Project with a Filter or a Join, chosen by the number of `items`.

### Why are the changes needed?
When the given `items` has many elements, a long Project is introduced into the query plan.
Depending on the length of `items`, we can replace it with a `Column.isin` predicate (when `len(items)` is at most `compute.isin_limit`) or a broadcast semi-join (otherwise) for better performance.
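
For illustration, a minimal standalone sketch of the two predicate shapes involved; the DataFrame, column names, and data here are hypothetical, not taken from the patch:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
# Illustrative data: "idx" stands in for the single index column.
sdf = spark.createDataFrame([("aa", 1), ("ab", 2), ("ba", 3)], ["idx", "val"])
items = ["aa", "ab"]

# Before: one OR'ed equality predicate per item, so the expression
# (and the plan) grows linearly with len(items).
cond = None
for item in items:
    pred = F.col("idx") == F.lit(item)
    cond = pred if cond is None else (cond | pred)

# After (short lists): a single isin predicate.
sdf.filter(F.col("idx").isin(items)).show()
```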

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests.

Closes #33998 from xinrong-databricks/impr_filter.

Authored-by: Xinrong Meng <xinrong.meng@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
Xinrong Meng 2021-09-21 10:20:15 -07:00 committed by Takuya UESHIN
parent 688b95b136
commit 33e463ccf9
2 changed files with 27 additions and 8 deletions

python/pyspark/pandas/frame.py

@@ -9995,13 +9995,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 raise ValueError("items should be a list-like object.")
             if axis == 0:
                 if len(index_scols) == 1:
-                    col = None
-                    for item in items:
-                        if col is None:
-                            col = index_scols[0] == SF.lit(item)
-                        else:
-                            col = col | (index_scols[0] == SF.lit(item))
-                elif len(index_scols) > 1:
+                    if len(items) <= ps.get_option("compute.isin_limit"):
+                        col = index_scols[0].isin([SF.lit(item) for item in items])
+                        return DataFrame(self._internal.with_filter(col))
+                    else:
+                        item_sdf_col = verify_temp_column_name(
+                            self._internal.spark_frame, "__item__"
+                        )
+                        item_sdf = default_session().createDataFrame(
+                            pd.DataFrame({item_sdf_col: items})
+                        )
+                        joined_sdf = self._internal.spark_frame.join(
+                            other=F.broadcast(item_sdf),
+                            on=(index_scols[0] == scol_for(item_sdf, item_sdf_col)),
+                            how="semi",
+                        )
+                        return DataFrame(self._internal.with_new_sdf(joined_sdf))
+                else:
+                    # for multi-index
                     col = None
                     for item in items:
@@ -10019,7 +10031,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                             col = midx_col
                         else:
                             col = col | midx_col
-                return DataFrame(self._internal.with_filter(col))
+                    return DataFrame(self._internal.with_filter(col))
             else:
                 return self[items]
         elif like is not None:
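
When `len(items)` exceeds `compute.isin_limit`, the patch instead puts the wanted keys into a one-column DataFrame and runs a broadcast left-semi join against it. A standalone sketch of the same technique in plain PySpark, with hypothetical data and column names:

```python
import pandas as pd
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([("aa", 1), ("ab", 2), ("ba", 3)], ["idx", "val"])
items = ["aa", "ab"]  # imagine a list too long for a single isin predicate

# Broadcast the small key DataFrame; a semi join keeps the matching
# rows of sdf without adding any columns from item_sdf.
item_sdf = spark.createDataFrame(pd.DataFrame({"__item__": items}))
filtered = sdf.join(
    F.broadcast(item_sdf),
    on=sdf["idx"] == item_sdf["__item__"],
    how="semi",
)
filtered.show()
```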

python/pyspark/pandas/tests/test_dataframe.py

@@ -4313,6 +4313,13 @@ class DataFrameTest(PandasOnSparkTestCase, SQLTestUtils):
             psdf.filter(items=["ab", "aa"], axis=0).sort_index(),
             pdf.filter(items=["ab", "aa"], axis=0).sort_index(),
         )
+        with option_context("compute.isin_limit", 0):
+            self.assert_eq(
+                psdf.filter(items=["ab", "aa"], axis=0).sort_index(),
+                pdf.filter(items=["ab", "aa"], axis=0).sort_index(),
+            )
         self.assert_eq(
             psdf.filter(items=["ba", "db"], axis=1).sort_index(),
             pdf.filter(items=["ba", "db"], axis=1).sort_index(),