[SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator

### What changes were proposed in this pull request?

This PR allows Python `toLocalIterator` to prefetch the next partition while the current one is being collected. The PR also adds a demo micro benchmark in the examples directory, which we may or may not wish to keep.

### Why are the changes needed?

In https://issues.apache.org/jira/browse/SPARK-23961 / 5e79ae3b40 we changed PySpark to only pull one partition at a time. This is memory efficient, but when partitions take time to compute it means we spend more time blocked waiting for each one.

### Does this PR introduce any user-facing change?

A new `prefetchPartitions` parameter is added to `toLocalIterator` on both `RDD` and `DataFrame`.
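For example (a quick sketch assuming an existing `SparkContext` as `sc`; `process` stands in for arbitrary per-element work):

```python
rdd = sc.parallelize(range(100), 10)

# With prefetch enabled, the next partition is computed on the cluster
# while this loop is still consuming the current one.
for x in rdd.toLocalIterator(prefetchPartitions=True):
    process(x)
```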

### How was this patch tested?

A new unit test inside of `test_rdd.py` checks the times at which the elements are evaluated. Another test, checking that the results remain the same, is added to `test_dataframe.py`.

I also ran a micro benchmark in the examples directory `prefetch.py` which shows an improvement of ~40% in this specific use case.
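The workload has roughly the following shape (a sketch of what such a benchmark might look like, not the actual `prefetch.py`): each partition takes about a second to compute and the local consumer also spends about a second per element, so prefetch can overlap the two.

```python
import time
from pyspark import SparkContext

sc = SparkContext(appName="prefetch-benchmark")

# Each partition takes ~1s to compute, so job time dominates.
rdd = sc.parallelize(range(16), 16).map(lambda x: (time.sleep(1), x)[1])

for prefetch in (True, False):
    start = time.time()
    for _ in rdd.toLocalIterator(prefetchPartitions=prefetch):
        time.sleep(1)  # the consumer also does ~1s of work per element
    print("%s time: %s" % ("Prefetch" if prefetch else "Regular", time.time() - start))
```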

> 19/08/16 17:11:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
> Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
> Setting default log level to "WARN".
> To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
> Running timers:
>
> Results:
>
> Prefetch time: 100.228110831
> Regular time: 188.341721614

Closes #25515 from holdenk/SPARK-27659-allow-pyspark-tolocalitr-to-prefetch.

Authored-by: Holden Karau <hkarau@apple.com>
Signed-off-by: Holden Karau <hkarau@apple.com>
Commit 42050c3f4f (parent 27d0c3f913), Holden Karau, 2019-09-20 09:59:31 -07:00.
6 changed files with 61 additions and 10 deletions.

`core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala`:

```diff
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag

 import org.apache.hadoop.conf.Configuration
@@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
    * data collected from this job, the secret for authentication, and a socket auth
    * server object that can be used to join the JVM serving thread in Python.
    */
-  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
+  def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
     val handleFunc = (sock: Socket) => {
       val out = new DataOutputStream(sock.getOutputStream)
       val in = new DataInputStream(sock.getInputStream)
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
-          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+          var result: Array[Any] = null
+          rdd.sparkContext.submitJob(
+            rdd,
+            (iter: Iterator[Any]) => iter.toArray,
+            Seq(i), // The partition we are evaluating
+            (_, res: Array[Any]) => result = res,
+            result)
         }
+        val prefetchIter = collectPartitionIter.buffered

         // Write data until iteration is complete, client stops iteration, or error occurs
         var complete = false
@@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
           // Read request for data, value of zero will stop iteration or non-zero to continue
           if (in.readInt() == 0) {
             complete = true
-          } else if (collectPartitionIter.hasNext) {
+          } else if (prefetchIter.hasNext) {
             // Client requested more data, attempt to collect the next partition
-            val partitionArray = collectPartitionIter.next()
+            val partitionFuture = prefetchIter.next()
+            // Cause the next job to be submitted if prefetchPartitions is enabled.
+            if (prefetchPartitions) {
+              prefetchIter.headOption
+            }
+            val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
             // Send response there is a partition to read
             out.writeInt(1)
```
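The prefetch hinges on Scala's `BufferedIterator`: since `collectPartitionIter` is a lazy iterator whose `next()` submits a job, calling `headOption` on the buffered view forces the next job to be submitted before we block awaiting the current future. The same one-ahead pattern, sketched in Python purely for illustration (a hypothetical helper, not PySpark internals):

```python
from concurrent.futures import ThreadPoolExecutor

def iter_partitions(jobs, prefetch=True):
    """jobs: zero-arg callables, each collecting one partition."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        futures = (pool.submit(job) for job in jobs)  # lazy: submits on next()
        current = next(futures, None)
        while current is not None:
            # With prefetch, submit the next partition's job *before* blocking
            # on the current one (the analogue of prefetchIter.headOption).
            upcoming = next(futures, None) if prefetch else None
            yield current.result()
            current = upcoming if prefetch else next(futures, None)
```

With `prefetch=True` the worker computes partition `i+1` while the caller is still processing partition `i`; with `prefetch=False` each job is submitted only once the caller asks for more.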

`python/pyspark/rdd.py`:

```diff
@@ -2437,17 +2437,23 @@ class RDD(object):
         hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
         return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)

-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Return an iterator that contains all of the elements in this RDD.
         The iterator will consume as much memory as the largest partition in this RDD.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.

         >>> rdd = sc.parallelize(range(10))
         >>> [x for x in rdd.toLocalIterator()]
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
         """
         with SCCallSiteSync(self.context) as css:
-            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+            sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
+                self._jrdd.rdd(),
+                prefetchPartitions)
         return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)

     def barrier(self):
```

`python/pyspark/sql/dataframe.py`:

```diff
@@ -520,16 +520,20 @@ class DataFrame(object):

     @ignore_unicode_prefix
     @since(2.0)
-    def toLocalIterator(self):
+    def toLocalIterator(self, prefetchPartitions=False):
         """
         Returns an iterator that contains all of the rows in this :class:`DataFrame`.
         The iterator will consume as much memory as the largest partition in this DataFrame.
+        With prefetch it may consume up to the memory of the 2 largest partitions.
+
+        :param prefetchPartitions: If Spark should pre-fetch the next partition
+                                   before it is needed.

         >>> list(df.toLocalIterator())
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.toPythonIterator()
+            sock_info = self._jdf.toPythonIterator(prefetchPartitions)
         return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

     @ignore_unicode_prefix
```

`python/pyspark/sql/tests/test_dataframe.py`:

```diff
@@ -690,6 +690,12 @@ class DataFrameTests(ReusedSQLTestCase):
         expected = df.collect()
         self.assertEqual(expected, list(it))

+    def test_to_local_iterator_prefetch(self):
+        df = self.spark.range(8, numPartitions=4)
+        expected = df.collect()
+        it = df.toLocalIterator(prefetchPartitions=True)
+        self.assertEqual(expected, list(it))
+
     def test_to_local_iterator_not_fully_consumed(self):
         # SPARK-23961: toLocalIterator throws exception when not fully consumed
         # Create a DataFrame large enough so that write to socket will eventually block
```

`python/pyspark/tests/test_rdd.py`:

```diff
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from datetime import datetime, timedelta
 import hashlib
 import os
 import random
 import sys
 import tempfile
+import time
 from glob import glob

 from py4j.protocol import Py4JJavaError
@@ -68,6 +70,26 @@ class RDDTests(ReusedPySparkTestCase):
         it2 = rdd2.toLocalIterator()
         self.assertEqual([1, 2, 3], sorted(it2))

+    def test_to_localiterator_prefetch(self):
+        # Test that we fetch the next partition in parallel
+        # We do this by returning the current time and:
+        # reading the first elem, waiting, and reading the second elem
+        # If not in parallel then these would be at different times
+        # But since they are being computed in parallel we see the time
+        # is "close enough" to the same.
+        rdd = self.sc.parallelize(range(2), 2)
+        times1 = rdd.map(lambda x: datetime.now())
+        times2 = rdd.map(lambda x: datetime.now())
+        times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
+        times_iter = times2.toLocalIterator(prefetchPartitions=False)
+        times_prefetch_head = next(times_iter_prefetch)
+        times_head = next(times_iter)
+        time.sleep(2)
+        times_next = next(times_iter)
+        times_prefetch_next = next(times_iter_prefetch)
+        self.assertTrue(times_next - times_head >= timedelta(seconds=2))
+        self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
+
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"
```

`sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala`:

```diff
@@ -3356,9 +3356,9 @@ class Dataset[T] private[sql](
     }
   }

-  private[sql] def toPythonIterator(): Array[Any] = {
+  private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
     withNewExecutionId {
-      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+      PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
     }
   }
```