[SPARK-3478] [PySpark] Profile the Python tasks

This patch adds profiling support for PySpark. It shows the profiling results
before the driver exits; here is one example:

```
============================================================
Profile of RDD<id=3>
============================================================
         5146507 function calls (5146487 primitive calls) in 71.094 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  5144576   68.331    0.000   68.331    0.000 statcounter.py:44(merge)
       20    2.735    0.137   71.071    3.554 statcounter.py:33(__init__)
       20    0.017    0.001    0.017    0.001 {cPickle.dumps}
     1024    0.003    0.000    0.003    0.000 t.py:16(<lambda>)
       20    0.001    0.000    0.001    0.000 {reduce}
       21    0.001    0.000    0.001    0.000 {cPickle.loads}
       20    0.001    0.000    0.001    0.000 copy_reg.py:95(_slotnames)
       41    0.001    0.000    0.001    0.000 serializers.py:461(read_int)
       40    0.001    0.000    0.002    0.000 serializers.py:179(_batched)
       62    0.000    0.000    0.000    0.000 {method 'read' of 'file' objects}
       20    0.000    0.000   71.072    3.554 rdd.py:863(<lambda>)
       20    0.000    0.000    0.001    0.000 serializers.py:198(load_stream)
    40/20    0.000    0.000   71.072    3.554 rdd.py:2093(pipeline_func)
       41    0.000    0.000    0.002    0.000 serializers.py:130(load_stream)
       40    0.000    0.000   71.072    1.777 rdd.py:304(func)
       20    0.000    0.000   71.094    3.555 worker.py:82(process)
```

Users can also show the profile results manually with `sc.show_profiles()` or dump them to disk
with `sc.dump_profiles(path)`, for example:

```python
>>> sc._conf.set("spark.python.profile", "true")
>>> rdd = sc.parallelize(range(100)).map(str)
>>> rdd.count()
100
>>> sc.show_profiles()
============================================================
Profile of RDD<id=1>
============================================================
         284 function calls (276 primitive calls) in 0.001 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    0.000    0.000    0.000    0.000 serializers.py:198(load_stream)
        4    0.000    0.000    0.000    0.000 {reduce}
     12/4    0.000    0.000    0.001    0.000 rdd.py:2092(pipeline_func)
        4    0.000    0.000    0.000    0.000 {cPickle.loads}
        4    0.000    0.000    0.000    0.000 {cPickle.dumps}
      104    0.000    0.000    0.000    0.000 rdd.py:852(<genexpr>)
        8    0.000    0.000    0.000    0.000 serializers.py:461(read_int)
       12    0.000    0.000    0.000    0.000 rdd.py:303(func)
```
Profiling is disabled by default and can be enabled by setting "spark.python.profile=true".

Users can also have the results dumped to disk automatically for later analysis by setting "spark.python.profile.dump=path_to_dump".
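
For reference, a minimal sketch of wiring both settings up before the context is created (the app name and dump directory are placeholders, and this assumes a driver running with this patch applied):

```python
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("profiling-demo")                          # placeholder app name
        .set("spark.python.profile", "true")                   # turn on per-RDD profiling
        .set("spark.python.profile.dump", "/tmp/pyspark_profiles"))  # optional: dump at driver exit
sc = SparkContext(conf=conf)

sc.parallelize(range(100)).map(str).count()
sc.show_profiles()   # print the collected stats now instead of waiting for driver exit
```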

This is a bugfix of #2351. cc JoshRosen

Author: Davies Liu <davies.liu@gmail.com>

Closes #2556 from davies/profiler and squashes the following commits:

e68df5a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
858e74c [Davies Liu] compatitable with python 2.6
7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles()
2b0daf2 [Davies Liu] fix docs
7a56c24 [Davies Liu] bugfix
cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext
fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
09d02c3 [Davies Liu] Merge branch 'master' into profiler
c23865c [Davies Liu] Merge branch 'master' into profiler
15d6f18 [Davies Liu] add docs for two configs
dadee1a [Davies Liu] add docs string and clear profiles after show or dump
4f8309d [Davies Liu] address comment, add tests
0a5b6eb [Davies Liu] fix Python UDF
4b20494 [Davies Liu] add profile for python
Davies Liu 2014-09-30 18:24:57 -07:00 committed by Josh Rosen
parent d75496b189
commit c5414b6818
7 changed files with 127 additions and 7 deletions


@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
    used during aggregation goes above this amount, it will spill the data into disks.
  </td>
</tr>
<tr>
  <td><code>spark.python.profile</code></td>
  <td>false</td>
  <td>
    Enable profiling in the Python worker. The profile results will show up via `sc.show_profiles()`,
    or they will be displayed before the driver exits. They can also be dumped to disk with
    `sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
    they will not be displayed again automatically before the driver exits.
  </td>
</tr>
<tr>
  <td><code>spark.python.profile.dump</code></td>
  <td>(none)</td>
  <td>
    The directory into which to dump the profile results before the driver exits.
    The results will be dumped as a separate file for each RDD. They can be loaded
    with `pstats.Stats()`. If this is specified, the profile results will not be displayed
    automatically.
  </td>
</tr>
<tr>
  <td><code>spark.python.worker.reuse</code></td>
  <td>true</td>

@@ -215,6 +215,21 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class PStatsParam(AccumulatorParam):
    """PStatsParam is used to merge pstats.Stats"""

    @staticmethod
    def zero(value):
        return None

    @staticmethod
    def addInPlace(value1, value2):
        if value1 is None:
            return value2
        value1.add(value2)
        return value1


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    """


@@ -20,6 +20,7 @@ import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
import atexit
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -30,7 +31,6 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
    PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -192,6 +192,9 @@ class SparkContext(object):
            self._temp_dir = \
                self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

        # profiling stats collected for each PythonRDD
        self._profile_stats = []

    def _initialize_context(self, jconf):
        """
        Initialize SparkContext in function to allow subclass specific initialization
@@ -792,6 +795,40 @@ class SparkContext(object):
        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
        return list(mappedRDD._collect_iterator_through_file(it))

    def _add_profile(self, id, profileAcc):
        if not self._profile_stats:
            dump_path = self._conf.get("spark.python.profile.dump")
            if dump_path:
                atexit.register(self.dump_profiles, dump_path)
            else:
                atexit.register(self.show_profiles)

        self._profile_stats.append([id, profileAcc, False])

    def show_profiles(self):
        """ Print the profile stats to stdout """
        for i, (id, acc, showed) in enumerate(self._profile_stats):
            stats = acc.value
            if not showed and stats:
                print "=" * 60
                print "Profile of RDD<id=%d>" % id
                print "=" * 60
                stats.sort_stats("time", "cumulative").print_stats()
                # mark it as showed
                self._profile_stats[i][2] = True

    def dump_profiles(self, path):
        """ Dump the profile stats into directory `path`
        """
        if not os.path.exists(path):
            os.makedirs(path)
        for id, acc, _ in self._profile_stats:
            stats = acc.value
            if stats:
                p = os.path.join(path, "rdd_%d.pstats" % id)
                stats.dump_stats(p)
        self._profile_stats = []


def _test():
    import atexit
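
The dump-at-exit behaviour above relies on `atexit.register()` accepting extra arguments that are passed to the callback when the interpreter exits; a tiny standalone illustration (the function body and path are placeholders, not Spark code):

```python
import atexit

def dump_profiles(path):
    # stand-in for SparkContext.dump_profiles; the real method writes one
    # .pstats file per RDD into `path`
    print("dumping profiles to %s" % path)

atexit.register(dump_profiles, "/tmp/pyspark_profiles")  # extra args are forwarded at exit
```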


@@ -15,7 +15,6 @@
# limitations under the License.
#
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
@@ -32,6 +31,7 @@ import bisect
from random import Random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
    PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2080,7 +2080,9 @@ class PipelinedRDD(RDD):
            return self._jrdd_val
        if self._bypass_serializer:
            self._jrdd_deserializer = NoOpSerializer()
        command = (self.func, self._prev_jrdd_deserializer,
        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
        command = (self.func, profileStats, self._prev_jrdd_deserializer,
                   self._jrdd_deserializer)
        # the serialized command will be compressed by broadcast
        ser = CloudPickleSerializer()
@@ -2102,6 +2104,10 @@ class PipelinedRDD(RDD):
                                             self.ctx.pythonExec,
                                             broadcast_vars, self.ctx._javaAccumulator)
        self._jrdd_val = python_rdd.asJavaRDD()

        if enable_profile:
            self._id = self._jrdd_val.id()
            self.ctx._add_profile(self._id, profileStats)
        return self._jrdd_val

    def id(self):


@@ -960,7 +960,7 @@ class SQLContext(object):
        [Row(c0=4)]
        """
        func = lambda _, it: imap(lambda x: f(*x), it)
        command = (func,
        command = (func, None,
                   BatchedSerializer(PickleSerializer(), 1024),
                   BatchedSerializer(PickleSerializer(), 1024))
        ser = CloudPickleSerializer()


@@ -632,6 +632,36 @@ class TestRDDFunctions(PySparkTestCase):
        self.assertEquals(result.count(), 3)


class TestProfiler(PySparkTestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.profile", "true")
        self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)

    def test_profiler(self):
        def heavy_foo(x):
            for i in range(1 << 20):
                x = 1
        rdd = self.sc.parallelize(range(100))
        rdd.foreach(heavy_foo)
        profiles = self.sc._profile_stats
        self.assertEqual(1, len(profiles))
        id, acc, _ = profiles[0]
        stats = acc.value
        self.assertTrue(stats is not None)
        width, stat_list = stats.get_print_list([])
        func_names = [func_name for fname, n, func_name in stat_list]
        self.assertTrue("heavy_foo" in func_names)

        self.sc.show_profiles()
        d = tempfile.gettempdir()
        self.sc.dump_profiles(d)
        self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


class TestSQL(PySparkTestCase):

    def setUp(self):


@@ -23,6 +23,8 @@ import sys
import time
import socket
import traceback
import cProfile
import pstats
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,10 +92,21 @@ def main(infile, outfile):
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            command = pickleSer.loads(command.value)
        (func, deserializer, serializer) = command
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
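
A standalone sketch of the worker-side pattern above: run the task body under `cProfile`, wrap the result in `pstats.Stats`, clear the stream so the object pickles, and ship it onward (here it is just round-tripped through `pickle`; the task body is an arbitrary stand-in):

```python
import cProfile
import pickle
import pstats

def process():
    # arbitrary stand-in for the worker's deserialize/compute/serialize loop
    return [x * x for x in range(100000)]

p = cProfile.Profile()
p.runcall(process)                     # profile exactly one call, as the worker does
st = pstats.Stats(p).strip_dirs()
st.stream = None                       # drop the output stream so the Stats object pickles
blob = pickle.dumps(st)                # roughly what the accumulator carries back to the driver
restored = pickle.loads(blob)
restored.sort_stats("time", "cumulative").print_stats(5)
```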