[SPARK-4387][PySpark] Refactoring python profiling code to make it extensible

This PR is based on #3255; it fixes conflicts and code style.

Closes #3255.

Author: Yandu Oppacher <yandu.oppacher@jadedpixel.com>
Author: Davies Liu <davies@databricks.com>

Closes #3901 from davies/refactor-python-profile-code and squashes the following commits:

b4a9306 [Davies Liu] fix tests
4b79ce8 [Davies Liu] add docstring for profiler_cls
2700e47 [Davies Liu] use BasicProfiler as default
349e341 [Davies Liu] more refactor
6a5d4df [Davies Liu] refactor and fix tests
31bf6b6 [Davies Liu] fix code style
0864b5d [Yandu Oppacher] Remove unused method
76a6c37 [Yandu Oppacher] Added a profile collector to accumulate the profilers per stage
9eefc36 [Yandu Oppacher] Fix doc
9ace076 [Yandu Oppacher] Refactor of profiler, and moved tests around
8739aff [Yandu Oppacher] Code review fixes
9bda3ec [Yandu Oppacher] Refactor profiler code
Yandu Oppacher authored on 2015-01-28 13:48:06 -08:00; committed by Josh Rosen
parent a731314c31
commit 3bead67d59
9 changed files with 235 additions and 71 deletions


@@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful
or it will be displayed before the driver exits. It can also be dumped to disk by
`sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
they will not be displayed automatically before the driver exits.
By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
passing a profiler class in as a parameter to the `SparkContext` constructor.
</td>
</tr>
<tr>

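For reference, a minimal usage sketch of the two knobs described above, the `spark.python.profile` conf and the new `profiler_cls` constructor parameter (the `StdoutProfiler` subclass and app name are hypothetical, for illustration only):

from pyspark import SparkConf, SparkContext
from pyspark.profiler import BasicProfiler

class StdoutProfiler(BasicProfiler):
    # hypothetical subclass: only the output format is customized
    def show(self, id):
        print "My profile of RDD<id=%d>" % id

conf = SparkConf().set("spark.python.profile", "true")
sc = SparkContext("local", "profiler-demo", conf=conf, profiler_cls=StdoutProfiler)
sc.parallelize(range(100)).map(lambda x: x * x).count()
sc.show_profiles()                 # print the accumulated stats, one block per RDD
sc.dump_profiles("/tmp/profiles")  # or dump them to disk instead
sc.stop()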

@@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
from pyspark.profiler import Profiler, BasicProfiler
# for backward compatibility
from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
@@ -52,4 +53,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
"Profiler", "BasicProfiler",
]


@@ -215,21 +215,6 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class PStatsParam(AccumulatorParam):
"""PStatsParam is used to merge pstats.Stats"""
@staticmethod
def zero(value):
return None
@staticmethod
def addInPlace(value1, value2):
if value1 is None:
return value2
value1.add(value2)
return value1
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
"""


@@ -20,7 +20,6 @@ import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
import atexit
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
from pyspark.profiler import ProfilerCollector, BasicProfiler
from py4j.java_collections import ListConverter
@@ -66,7 +66,7 @@ class SparkContext(object):
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None, jsc=None):
gateway=None, jsc=None, profiler_cls=BasicProfiler):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -88,6 +88,9 @@ class SparkContext(object):
:param conf: A L{SparkConf} object setting Spark properties.
:param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
:param jsc: The JavaSparkContext instance (optional).
:param profiler_cls: A custom Profiler class used for profiling
(default is pyspark.profiler.BasicProfiler).
>>> from pyspark.context import SparkContext
@@ -102,14 +105,14 @@
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc)
conf, jsc, profiler_cls)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf, jsc):
conf, jsc, profiler_cls):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -192,7 +195,11 @@
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
# profiling stats collected for each PythonRDD
self._profile_stats = []
if self._conf.get("spark.python.profile", "false") == "true":
dump_path = self._conf.get("spark.python.profile.dump", None)
self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
else:
self.profiler_collector = None
def _initialize_context(self, jconf):
"""
@@ -826,39 +833,14 @@
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
def _add_profile(self, id, profileAcc):
if not self._profile_stats:
dump_path = self._conf.get("spark.python.profile.dump")
if dump_path:
atexit.register(self.dump_profiles, dump_path)
else:
atexit.register(self.show_profiles)
self._profile_stats.append([id, profileAcc, False])
def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, acc, showed) in enumerate(self._profile_stats):
stats = acc.value
if not showed and stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()
# mark it as showed
self._profile_stats[i][2] = True
self.profiler_collector.show_profiles()
def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
if not os.path.exists(path):
os.makedirs(path)
for id, acc, _ in self._profile_stats:
stats = acc.value
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)
self._profile_stats = []
self.profiler_collector.dump_profiles(path)
def _test():

python/pyspark/profiler.py (new file)

@@ -0,0 +1,172 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import cProfile
import pstats
import os
import atexit
from pyspark.accumulators import AccumulatorParam
class ProfilerCollector(object):
"""
This class keeps track of different profilers on a per-stage basis,
and is also used to create new profilers for the different stages.
"""
def __init__(self, profiler_cls, dump_path=None):
self.profiler_cls = profiler_cls
self.profile_dump_path = dump_path
self.profilers = []
def new_profiler(self, ctx):
""" Create a new profiler using class `profiler_cls` """
return self.profiler_cls(ctx)
def add_profiler(self, id, profiler):
""" Add a profiler for RDD `id` """
if not self.profilers:
if self.profile_dump_path:
atexit.register(self.dump_profiles, self.profile_dump_path)
else:
atexit.register(self.show_profiles)
self.profilers.append([id, profiler, False])
def dump_profiles(self, path):
""" Dump the profile stats into directory `path` """
for id, profiler, _ in self.profilers:
profiler.dump(id, path)
self.profilers = []
def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, profiler, showed) in enumerate(self.profilers):
if not showed and profiler:
profiler.show(id)
# mark it as showed
self.profilers[i][2] = True
class Profiler(object):
"""
.. note:: DeveloperApi
PySpark supports custom profilers; this allows different profilers to be
used, as well as output in formats other than those provided by the
BasicProfiler.
A custom profiler has to define or inherit the following methods:
profile - produces a system profile of some sort.
stats - returns the collected stats.
dump - dumps the profiles to a path.
add - adds a profile to the existing accumulated profile.
The profiler class is chosen when creating a SparkContext:
>>> from pyspark import SparkConf, SparkContext
>>> from pyspark import BasicProfiler
>>> class MyCustomProfiler(BasicProfiler):
... def show(self, id):
... print "My custom profiles for RDD:%s" % id
...
>>> conf = SparkConf().set("spark.python.profile", "true")
>>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.show_profiles()
My custom profiles for RDD:1
My custom profiles for RDD:2
>>> sc.stop()
"""
def __init__(self, ctx):
pass
def profile(self, func):
""" Do profiling on the function `func`"""
raise NotImplemented
def stats(self):
""" Return the collected profiling stats (pstats.Stats)"""
raise NotImplemented
def show(self, id):
""" Print the profile stats to stdout, id is the RDD id """
stats = self.stats()
if stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()
def dump(self, id, path):
""" Dump the profile into path, id is the RDD id """
if not os.path.exists(path):
os.makedirs(path)
stats = self.stats()
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)
class PStatsParam(AccumulatorParam):
"""PStatsParam is used to merge pstats.Stats"""
@staticmethod
def zero(value):
return None
@staticmethod
def addInPlace(value1, value2):
if value1 is None:
return value2
value1.add(value2)
return value1
class BasicProfiler(Profiler):
"""
BasicProfiler is the default profiler, which is implemented based on
cProfile and Accumulator
"""
def __init__(self, ctx):
Profiler.__init__(self, ctx)
# Creates a new accumulator for combining the profiles of different
# partitions of a stage
self._accumulator = ctx.accumulator(None, PStatsParam)
def profile(self, func):
""" Runs and profiles the method to_profile passed in. A profile object is returned. """
pr = cProfile.Profile()
pr.runcall(func)
st = pstats.Stats(pr)
st.stream = None # make it picklable
st.strip_dirs()
# Adds a new profile to the existing accumulated value
self._accumulator.add(st)
def stats(self):
return self._accumulator.value
if __name__ == "__main__":
import doctest
doctest.testmod()

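To make the collector's contract concrete, here is a minimal sketch that drives ProfilerCollector directly, outside of any SparkContext (DummyProfiler is hypothetical and records only a call count):

from pyspark.profiler import Profiler, ProfilerCollector

class DummyProfiler(Profiler):
    def __init__(self, ctx):
        self.calls = 0

    def profile(self, func):
        # run the work unprofiled, just counting invocations
        self.calls += 1
        func()

    def show(self, id):
        print "RDD %d: profiled %d call(s)" % (id, self.calls)

collector = ProfilerCollector(DummyProfiler, dump_path=None)
profiler = collector.new_profiler(None)    # ctx is unused by DummyProfiler
profiler.profile(lambda: sum(range(1000)))
collector.add_profiler(1, profiler)        # first add registers atexit reporting
collector.show_profiles()                  # prints: RDD 1: profiled 1 call(s)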

@@ -31,7 +31,6 @@ import bisect
import random
from math import sqrt, log, isinf, isnan
from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2132,9 +2131,13 @@ class PipelinedRDD(RDD):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
command = (self.func, profileStats, self._prev_jrdd_deserializer,
if self.ctx.profiler_collector:
profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
else:
profiler = None
command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
@@ -2157,9 +2160,9 @@
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
if enable_profile:
if profiler:
self._id = self._jrdd_val.id()
self.ctx._add_profile(self._id, profileStats)
self.ctx.profiler_collector.add_profiler(self._id, profiler)
return self._jrdd_val
def id(self):


@@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, External
from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType
from pyspark import shuffle
from pyspark.profiler import BasicProfiler
_have_scipy = False
_have_numpy = False
@@ -743,16 +744,12 @@ class ProfilerTests(PySparkTestCase):
self.sc = SparkContext('local[4]', class_name, conf=conf)
def test_profiler(self):
self.do_computation()
def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)
profiles = self.sc._profile_stats
self.assertEqual(1, len(profiles))
id, acc, _ = profiles[0]
stats = acc.value
profilers = self.sc.profiler_collector.profilers
self.assertEqual(1, len(profilers))
id, profiler, _ = profilers[0]
stats = profiler.stats()
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
@@ -763,6 +760,31 @@
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
def test_custom_profiler(self):
class TestCustomProfiler(BasicProfiler):
def show(self, id):
self.result = "Custom formatting"
self.sc.profiler_collector.profiler_cls = TestCustomProfiler
self.do_computation()
profilers = self.sc.profiler_collector.profilers
self.assertEqual(1, len(profilers))
_, profiler, _ = profilers[0]
self.assertTrue(isinstance(profiler, TestCustomProfiler))
self.sc.show_profiles()
self.assertEqual("Custom formatting", profiler.result)
def do_computation(self):
def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)
class ExamplePointUDT(UserDefinedType):
"""


@@ -23,8 +23,6 @@ import sys
import time
import socket
import traceback
import cProfile
import pstats
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,19 +88,15 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
(func, stats, deserializer, serializer) = command
(func, profiler, deserializer, serializer) = command
init_time = time.time()
def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)
if stats:
p = cProfile.Profile()
p.runcall(process)
st = pstats.Stats(p)
st.stream = None # make it picklable
stats.add(st.strip_dirs())
if profiler:
profiler.profile(process)
else:
process()
except Exception:

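With this change the worker's only dependency on profiling is the single `profiler.profile(process)` call, so any Profiler subclass shipped in the serialized command will work. A hypothetical sketch of an alternative profiler (note that, like BasicProfiler, a real one would need an accumulator so per-partition results reach the driver):

import time
from pyspark.profiler import Profiler

class WallClockProfiler(Profiler):
    # hypothetical: records wall-clock time instead of cProfile stats
    def __init__(self, ctx):
        Profiler.__init__(self, ctx)
        # a real implementation would create an accumulator here, as
        # BasicProfiler does, so worker-side timings reach the driver
        self.elapsed = 0.0

    def profile(self, func):
        start = time.time()
        func()
        self.elapsed += time.time() - start

    def show(self, id):
        print "RDD %d: %.3fs wall-clock" % (id, self.elapsed)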

@@ -57,6 +57,7 @@ function run_core_tests() {
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
run_test "pyspark/serializers.py"
run_test "pyspark/profiler.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}