6d06ff6f7e
## What changes were proposed in this pull request? In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing. The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan; this, in turn, ends up being executed inefficiently because limits in the middle of plans are now implemented by repartitioning to a single task rather than by running a `take()` job on the driver (this was done in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.). In order to fix this performance problem I think that we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch modifies `DataFrame.collect()` to first collect all results to the driver and then pass them to Python, allowing this query to be planned using Spark's `CollectLimit` optimizations. ## How was this patch tested? Added a regression test in `sql/tests.py` which asserts that the expected number of jobs, stages, and tasks are run for both queries. Author: Josh Rosen <joshrosen@databricks.com> Closes #15068 from JoshRosen/pyspark-collect-limit. |
||
---|---|---|
.. | ||
__init__.py | ||
catalog.py | ||
column.py | ||
conf.py | ||
context.py | ||
dataframe.py | ||
functions.py | ||
group.py | ||
readwriter.py | ||
session.py | ||
streaming.py | ||
tests.py | ||
types.py | ||
utils.py | ||
window.py |