[SPARK-27659][PYTHON] Allow PySpark to prefetch during toLocalIterator
### What changes were proposed in this pull request?
This PR allows Python `toLocalIterator` to prefetch the next partition while the current partition is being collected. The PR also adds a demo micro-benchmark in the examples directory; we may or may not wish to keep it.
### Why are the changes needed?
In https://issues.apache.org/jira/browse/SPARK-23961 / 5e79ae3b40
we changed PySpark to only pull one partition at a time. This is memory efficient, but if partitions take time to compute this can mean we're spending more time blocking.
### Does this PR introduce any user-facing change?
A new optional parameter, `prefetchPartitions` (default `False`), is added to `toLocalIterator`.
### How was this patch tested?
A new unit test in `test_rdd.py` checks the times at which the elements are evaluated. Another test, which checks that the results remain the same, is added to `test_dataframe.py`.
I also ran the micro-benchmark `prefetch.py` from the examples directory, which shows an improvement of ~40% in this specific use case.
>
> 19/08/16 17:11:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
> Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
> Setting default log level to "WARN".
> To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
> Running timers:
>
> [Stage 32:> (0 + 1) / 1]
> Results:
>
> Prefetch time:
>
> 100.228110831
>
>
> Regular time:
>
> 188.341721614
>
>
>
Closes #25515 from holdenk/SPARK-27659-allow-pyspark-tolocalitr-to-prefetch.
Authored-by: Holden Karau <hkarau@apple.com>
Signed-off-by: Holden Karau <hkarau@apple.com>
This commit is contained in:
parent
27d0c3f913
commit
42050c3f4f
|
@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
|
|||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
import scala.concurrent.duration.Duration
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
@ -179,15 +180,22 @@ private[spark] object PythonRDD extends Logging {
|
|||
* data collected from this job, the secret for authentication, and a socket auth
|
||||
* server object that can be used to join the JVM serving thread in Python.
|
||||
*/
|
||||
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
|
||||
def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = false): Array[Any] = {
|
||||
val handleFunc = (sock: Socket) => {
|
||||
val out = new DataOutputStream(sock.getOutputStream)
|
||||
val in = new DataInputStream(sock.getInputStream)
|
||||
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
|
||||
// Collects a partition on each iteration
|
||||
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
|
||||
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
|
||||
var result: Array[Any] = null
|
||||
rdd.sparkContext.submitJob(
|
||||
rdd,
|
||||
(iter: Iterator[Any]) => iter.toArray,
|
||||
Seq(i), // The partition we are evaluating
|
||||
(_, res: Array[Any]) => result = res,
|
||||
result)
|
||||
}
|
||||
val prefetchIter = collectPartitionIter.buffered
|
||||
|
||||
// Write data until iteration is complete, client stops iteration, or error occurs
|
||||
var complete = false
|
||||
|
@ -196,10 +204,15 @@ private[spark] object PythonRDD extends Logging {
|
|||
// Read request for data, value of zero will stop iteration or non-zero to continue
|
||||
if (in.readInt() == 0) {
|
||||
complete = true
|
||||
} else if (collectPartitionIter.hasNext) {
|
||||
} else if (prefetchIter.hasNext) {
|
||||
|
||||
// Client requested more data, attempt to collect the next partition
|
||||
val partitionArray = collectPartitionIter.next()
|
||||
val partitionFuture = prefetchIter.next()
|
||||
// Cause the next job to be submitted if prefetchPartitions is enabled.
|
||||
if (prefetchPartitions) {
|
||||
prefetchIter.headOption
|
||||
}
|
||||
val partitionArray = ThreadUtils.awaitResult(partitionFuture, Duration.Inf)
|
||||
|
||||
// Send response there is a partition to read
|
||||
out.writeInt(1)
|
||||
|
|
|
@ -2437,17 +2437,23 @@ class RDD(object):
|
|||
hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
|
||||
return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
|
||||
|
||||
def toLocalIterator(self):
|
||||
def toLocalIterator(self, prefetchPartitions=False):
|
||||
"""
|
||||
Return an iterator that contains all of the elements in this RDD.
|
||||
The iterator will consume as much memory as the largest partition in this RDD.
|
||||
With prefetch it may consume up to the memory of the 2 largest partitions.
|
||||
|
||||
:param prefetchPartitions: If Spark should pre-fetch the next partition
|
||||
before it is needed.
|
||||
|
||||
>>> rdd = sc.parallelize(range(10))
|
||||
>>> [x for x in rdd.toLocalIterator()]
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
"""
|
||||
with SCCallSiteSync(self.context) as css:
|
||||
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
|
||||
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(
|
||||
self._jrdd.rdd(),
|
||||
prefetchPartitions)
|
||||
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
|
||||
|
||||
def barrier(self):
|
||||
|
|
|
@ -520,16 +520,20 @@ class DataFrame(object):
|
|||
|
||||
@ignore_unicode_prefix
|
||||
@since(2.0)
|
||||
def toLocalIterator(self):
|
||||
def toLocalIterator(self, prefetchPartitions=False):
|
||||
"""
|
||||
Returns an iterator that contains all of the rows in this :class:`DataFrame`.
|
||||
The iterator will consume as much memory as the largest partition in this DataFrame.
|
||||
With prefetch it may consume up to the memory of the 2 largest partitions.
|
||||
|
||||
:param prefetchPartitions: If Spark should pre-fetch the next partition
|
||||
before it is needed.
|
||||
|
||||
>>> list(df.toLocalIterator())
|
||||
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
|
||||
"""
|
||||
with SCCallSiteSync(self._sc) as css:
|
||||
sock_info = self._jdf.toPythonIterator()
|
||||
sock_info = self._jdf.toPythonIterator(prefetchPartitions)
|
||||
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
|
||||
|
||||
@ignore_unicode_prefix
|
||||
|
|
|
@ -690,6 +690,12 @@ class DataFrameTests(ReusedSQLTestCase):
|
|||
expected = df.collect()
|
||||
self.assertEqual(expected, list(it))
|
||||
|
||||
def test_to_local_iterator_prefetch(self):
# Check that enabling prefetchPartitions does not change the rows returned:
# the prefetching local iterator must yield exactly what collect() returns.
|
||||
# 8 rows requested across 4 partitions, so the iterator has to walk several partitions.
df = self.spark.range(8, numPartitions=4)
|
||||
expected = df.collect()
|
||||
it = df.toLocalIterator(prefetchPartitions=True)
|
||||
# Fully drain the iterator and compare against the collected rows.
self.assertEqual(expected, list(it))
|
||||
|
||||
def test_to_local_iterator_not_fully_consumed(self):
|
||||
# SPARK-23961: toLocalIterator throws exception when not fully consumed
|
||||
# Create a DataFrame large enough so that write to socket will eventually block
|
||||
|
|
|
@ -14,11 +14,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from glob import glob
|
||||
|
||||
from py4j.protocol import Py4JJavaError
|
||||
|
@ -68,6 +70,26 @@ class RDDTests(ReusedPySparkTestCase):
|
|||
it2 = rdd2.toLocalIterator()
|
||||
self.assertEqual([1, 2, 3], sorted(it2))
|
||||
|
||||
def test_to_localiterator_prefetch(self):
|
||||
# Test that we fetch the next partition in parallel
|
||||
# We do this by returning the current time and:
|
||||
# reading the first elem, waiting, and reading the second elem
|
||||
# If not in parallel then these would be at different times
|
||||
# But since they are being computed in parallel we see the time
|
||||
# is "close enough" to the same.
|
||||
# Two elements split into two slices (presumably one element per partition -- confirm).
rdd = self.sc.parallelize(range(2), 2)
|
||||
# Each mapped RDD replaces its element with the time at which it was evaluated.
times1 = rdd.map(lambda x: datetime.now())
|
||||
times2 = rdd.map(lambda x: datetime.now())
|
||||
# One iterator with prefetching and one without, so the two behaviours can be compared.
times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
|
||||
times_iter = times2.toLocalIterator(prefetchPartitions=False)
|
||||
times_prefetch_head = next(times_iter_prefetch)
|
||||
times_head = next(times_iter)
|
||||
# Pause between reading the first and the second element of both iterators.
time.sleep(2)
|
||||
times_next = next(times_iter)
|
||||
times_prefetch_next = next(times_iter_prefetch)
|
||||
# Without prefetch the second element is expected to be evaluated only after the sleep.
self.assertTrue(times_next - times_head >= timedelta(seconds=2))
|
||||
# With prefetch the second element is expected to have been evaluated close to the first.
self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
|
||||
|
||||
def test_save_as_textfile_with_unicode(self):
|
||||
# Regression test for SPARK-970
|
||||
x = u"\u00A1Hola, mundo!"
|
||||
|
|
|
@ -3356,9 +3356,9 @@ class Dataset[T] private[sql](
|
|||
}
|
||||
}
|
||||
|
||||
private[sql] def toPythonIterator(): Array[Any] = {
|
||||
private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
|
||||
withNewExecutionId {
|
||||
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
|
||||
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue