[SPARK-4387][PySpark] Refactoring python profiling code to make it extensible
This PR is based on #3255, fixing conflicts and code style. Closes #3255.

Author: Yandu Oppacher <yandu.oppacher@jadedpixel.com>
Author: Davies Liu <davies@databricks.com>

Closes #3901 from davies/refactor-python-profile-code and squashes the following commits:

b4a9306 [Davies Liu] fix tests
4b79ce8 [Davies Liu] add docstring for profiler_cls
2700e47 [Davies Liu] use BasicProfiler as default
349e341 [Davies Liu] more refactor
6a5d4df [Davies Liu] refactor and fix tests
31bf6b6 [Davies Liu] fix code style
0864b5d [Yandu Oppacher] Remove unused method
76a6c37 [Yandu Oppacher] Added a profile collector to accumulate the profilers per stage
9eefc36 [Yandu Oppacher] Fix doc
9ace076 [Yandu Oppacher] Refactor of profiler, and moved tests around
8739aff [Yandu Oppacher] Code review fixes
9bda3ec [Yandu Oppacher] Refactor profiler code
parent a731314c31
commit 3bead67d59
@@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful
     or it will be displayed before the driver exits. It can also be dumped to disk by
     `sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
     they will not be displayed automatically before the driver exits.
+
+    By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
+    passing a profiler class in as a parameter to the `SparkContext` constructor.
   </td>
 </tr>
 <tr>
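The documented workflow, end to end; a minimal sketch, assuming a PySpark build that includes this patch (the app name and dump path are illustrative):

    from pyspark import SparkConf, SparkContext

    # Enable Python profiling; BasicProfiler is used unless profiler_cls is given.
    conf = SparkConf().set("spark.python.profile", "true")
    sc = SparkContext("local", "profiling-demo", conf=conf)

    sc.parallelize(range(1000)).map(lambda x: 2 * x).count()

    sc.show_profiles()                   # print per-RDD cProfile output to stdout
    sc.dump_profiles("/tmp/profiles")    # or write rdd_<id>.pstats files to disk
    sc.stop()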
@@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.profiler import Profiler, BasicProfiler

 # for back compatibility
 from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row

@@ -52,4 +53,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+    "Profiler", "BasicProfiler",
 ]
@@ -215,21 +215,6 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


-class PStatsParam(AccumulatorParam):
-    """PStatsParam is used to merge pstats.Stats"""
-
-    @staticmethod
-    def zero(value):
-        return None
-
-    @staticmethod
-    def addInPlace(value1, value2):
-        if value1 is None:
-            return value2
-        value1.add(value2)
-        return value1
-
-
 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

     """
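PStatsParam is removed here but not deleted from the codebase; it moves into the new pyspark/profiler.py below. Its addInPlace is plain pstats.Stats aggregation; a standalone sketch of the same merge (the profiled lambdas are illustrative):

    import cProfile
    import pstats

    def profile_once(func):
        # Run func under cProfile and return its pstats.Stats.
        pr = cProfile.Profile()
        pr.runcall(func)
        return pstats.Stats(pr)

    # Merge two independent runs, as PStatsParam.addInPlace does with the
    # per-partition stats shipped back through an Accumulator.
    merged = profile_once(lambda: sum(range(10 ** 5)))
    merged.add(profile_once(lambda: sorted(range(10 ** 5), reverse=True)))
    merged.sort_stats("time", "cumulative").print_stats(5)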
@@ -20,7 +20,6 @@ import shutil
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-import atexit

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator

@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.profiler import ProfilerCollector, BasicProfiler

 from py4j.java_collections import ListConverter

@@ -66,7 +66,7 @@ class SparkContext(object):

     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None):
+                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -88,6 +88,9 @@ class SparkContext(object):
         :param conf: A L{SparkConf} object setting Spark properties.
         :param gateway: Use an existing gateway and JVM, otherwise a new JVM
                will be instantiated.
         :param jsc: The JavaSparkContext instance (optional).
+        :param profiler_cls: A class of custom Profiler used to do profiling
+               (default is pyspark.profiler.BasicProfiler).
+

         >>> from pyspark.context import SparkContext
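What the new parameter enables, in brief; a sketch in which SilentProfiler is a hypothetical subclass (only spark.python.profile and profiler_cls come from this patch):

    from pyspark import SparkConf, SparkContext, BasicProfiler

    class SilentProfiler(BasicProfiler):
        # Hypothetical subclass: collect stats but print nothing on show().
        def show(self, id):
            pass

    conf = SparkConf().set("spark.python.profile", "true")
    sc = SparkContext("local", "demo", conf=conf, profiler_cls=SilentProfiler)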
@@ -102,14 +105,14 @@ class SparkContext(object):
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc)
+                          conf, jsc, profiler_cls)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise

     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc):
+                 conf, jsc, profiler_cls):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size

@@ -192,7 +195,11 @@ class SparkContext(object):
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

         # profiling stats collected for each PythonRDD
-        self._profile_stats = []
+        if self._conf.get("spark.python.profile", "false") == "true":
+            dump_path = self._conf.get("spark.python.profile.dump", None)
+            self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
+        else:
+            self.profiler_collector = None

     def _initialize_context(self, jconf):
         """
@@ -826,39 +833,14 @@ class SparkContext(object):
         it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
         return list(mappedRDD._collect_iterator_through_file(it))

-    def _add_profile(self, id, profileAcc):
-        if not self._profile_stats:
-            dump_path = self._conf.get("spark.python.profile.dump")
-            if dump_path:
-                atexit.register(self.dump_profiles, dump_path)
-            else:
-                atexit.register(self.show_profiles)
-
-        self._profile_stats.append([id, profileAcc, False])
-
     def show_profiles(self):
         """ Print the profile stats to stdout """
-        for i, (id, acc, showed) in enumerate(self._profile_stats):
-            stats = acc.value
-            if not showed and stats:
-                print "=" * 60
-                print "Profile of RDD<id=%d>" % id
-                print "=" * 60
-                stats.sort_stats("time", "cumulative").print_stats()
-                # mark it as showed
-                self._profile_stats[i][2] = True
+        self.profiler_collector.show_profiles()

     def dump_profiles(self, path):
         """ Dump the profile stats into directory `path`
         """
-        if not os.path.exists(path):
-            os.makedirs(path)
-        for id, acc, _ in self._profile_stats:
-            stats = acc.value
-            if stats:
-                p = os.path.join(path, "rdd_%d.pstats" % id)
-                stats.dump_stats(p)
-        self._profile_stats = []
+        self.profiler_collector.dump_profiles(path)


 def _test():
python/pyspark/profiler.py (new file, 172 lines)

@@ -0,0 +1,172 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cProfile
import pstats
import os
import atexit

from pyspark.accumulators import AccumulatorParam


class ProfilerCollector(object):
    """
    This class keeps track of different profilers on a per
    stage basis. It is also used to create new profilers for
    the different stages.
    """

    def __init__(self, profiler_cls, dump_path=None):
        self.profiler_cls = profiler_cls
        self.profile_dump_path = dump_path
        self.profilers = []

    def new_profiler(self, ctx):
        """ Create a new profiler using class `profiler_cls` """
        return self.profiler_cls(ctx)

    def add_profiler(self, id, profiler):
        """ Add a profiler for RDD `id` """
        if not self.profilers:
            if self.profile_dump_path:
                atexit.register(self.dump_profiles, self.profile_dump_path)
            else:
                atexit.register(self.show_profiles)

        self.profilers.append([id, profiler, False])

    def dump_profiles(self, path):
        """ Dump the profile stats into directory `path` """
        for id, profiler, _ in self.profilers:
            profiler.dump(id, path)
        self.profilers = []

    def show_profiles(self):
        """ Print the profile stats to stdout """
        for i, (id, profiler, showed) in enumerate(self.profilers):
            if not showed and profiler:
                profiler.show(id)
                # mark it as showed
                self.profilers[i][2] = True


class Profiler(object):
    """
    .. note:: DeveloperApi

    PySpark supports custom profilers; this is to allow for different profilers to
    be used as well as outputting to different formats than what is provided in the
    BasicProfiler.

    A custom profiler has to define or inherit the following methods:
        profile - will produce a system profile of some sort.
        stats - return the collected stats.
        dump - dumps the profiles to a path
        add - adds a profile to the existing accumulated profile

    The profiler class is chosen when creating a SparkContext

    >>> from pyspark import SparkConf, SparkContext
    >>> from pyspark import BasicProfiler
    >>> class MyCustomProfiler(BasicProfiler):
    ...     def show(self, id):
    ...         print "My custom profiles for RDD:%s" % id
    ...
    >>> conf = SparkConf().set("spark.python.profile", "true")
    >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
    >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
    >>> sc.show_profiles()
    My custom profiles for RDD:1
    My custom profiles for RDD:2
    >>> sc.stop()
    """

    def __init__(self, ctx):
        pass

    def profile(self, func):
        """ Do profiling on the function `func` """
        raise NotImplementedError

    def stats(self):
        """ Return the collected profiling stats (pstats.Stats) """
        raise NotImplementedError

    def show(self, id):
        """ Print the profile stats to stdout, id is the RDD id """
        stats = self.stats()
        if stats:
            print "=" * 60
            print "Profile of RDD<id=%d>" % id
            print "=" * 60
            stats.sort_stats("time", "cumulative").print_stats()

    def dump(self, id, path):
        """ Dump the profile into path, id is the RDD id """
        if not os.path.exists(path):
            os.makedirs(path)
        stats = self.stats()
        if stats:
            p = os.path.join(path, "rdd_%d.pstats" % id)
            stats.dump_stats(p)


class PStatsParam(AccumulatorParam):
    """PStatsParam is used to merge pstats.Stats"""

    @staticmethod
    def zero(value):
        return None

    @staticmethod
    def addInPlace(value1, value2):
        if value1 is None:
            return value2
        value1.add(value2)
        return value1


class BasicProfiler(Profiler):
    """
    BasicProfiler is the default profiler, which is implemented based on
    cProfile and Accumulator
    """
    def __init__(self, ctx):
        Profiler.__init__(self, ctx)
        # Creates a new accumulator for combining the profiles of different
        # partitions of a stage
        self._accumulator = ctx.accumulator(None, PStatsParam)

    def profile(self, func):
        """ Runs and profiles the function `func` passed in. A profile object is returned. """
        pr = cProfile.Profile()
        pr.runcall(func)
        st = pstats.Stats(pr)
        st.stream = None  # make it picklable
        st.strip_dirs()

        # Adds a new profile to the existing accumulated value
        self._accumulator.add(st)

    def stats(self):
        return self._accumulator.value


if __name__ == "__main__":
    import doctest
    doctest.testmod()
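The pieces above also compose without a SparkContext; a minimal sketch that drives ProfilerCollector directly, where LocalProfiler is a hypothetical stand-in for BasicProfiler that keeps its pstats.Stats in-process instead of in an Accumulator (assumes this branch's pyspark is importable):

    import cProfile
    import pstats

    from pyspark.profiler import Profiler, ProfilerCollector

    class LocalProfiler(Profiler):
        # Hypothetical: like BasicProfiler, but with no Accumulator involved.
        def __init__(self, ctx):
            Profiler.__init__(self, ctx)
            self._stats = None

        def profile(self, func):
            pr = cProfile.Profile()
            pr.runcall(func)
            self._stats = pstats.Stats(pr).strip_dirs()

        def stats(self):
            return self._stats

    collector = ProfilerCollector(LocalProfiler, dump_path=None)
    p = collector.new_profiler(ctx=None)  # ctx is unused by LocalProfiler
    p.profile(lambda: sum(x * x for x in range(10 ** 5)))
    collector.add_profiler(0, p)
    collector.show_profiles()  # prints the "Profile of RDD<id=0>" report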
@@ -31,7 +31,6 @@ import bisect
 import random
 from math import sqrt, log, isinf, isnan

-from pyspark.accumulators import PStatsParam
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer

@@ -2132,9 +2131,13 @@ class PipelinedRDD(RDD):
             return self._jrdd_val
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
-        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
-        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
-        command = (self.func, profileStats, self._prev_jrdd_deserializer,
+
+        if self.ctx.profiler_collector:
+            profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
+        else:
+            profiler = None
+
+        command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()

@@ -2157,9 +2160,9 @@ class PipelinedRDD(RDD):
                                              broadcast_vars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()

-        if enable_profile:
+        if profiler:
             self._id = self._jrdd_val.id()
-            self.ctx._add_profile(self._id, profileStats)
+            self.ctx.profiler_collector.add_profiler(self._id, profiler)
         return self._jrdd_val

     def id(self):
@@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, External
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
 from pyspark import shuffle
+from pyspark.profiler import BasicProfiler

 _have_scipy = False
 _have_numpy = False

@@ -743,16 +744,12 @@ class ProfilerTests(PySparkTestCase):
         self.sc = SparkContext('local[4]', class_name, conf=conf)

     def test_profiler(self):
-        def heavy_foo(x):
-            for i in range(1 << 20):
-                x = 1
+        self.do_computation()

-        rdd = self.sc.parallelize(range(100))
-        rdd.foreach(heavy_foo)
-        profiles = self.sc._profile_stats
-        self.assertEqual(1, len(profiles))
-        id, acc, _ = profiles[0]
-        stats = acc.value
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        id, profiler, _ = profilers[0]
+        stats = profiler.stats()
         self.assertTrue(stats is not None)
         width, stat_list = stats.get_print_list([])
         func_names = [func_name for fname, n, func_name in stat_list]

@@ -763,6 +760,31 @@ class ProfilerTests(PySparkTestCase):
         self.sc.dump_profiles(d)
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))

+    def test_custom_profiler(self):
+        class TestCustomProfiler(BasicProfiler):
+            def show(self, id):
+                self.result = "Custom formatting"
+
+        self.sc.profiler_collector.profiler_cls = TestCustomProfiler
+
+        self.do_computation()
+
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        _, profiler, _ = profilers[0]
+        self.assertTrue(isinstance(profiler, TestCustomProfiler))
+
+        self.sc.show_profiles()
+        self.assertEqual("Custom formatting", profiler.result)
+
+    def do_computation(self):
+        def heavy_foo(x):
+            for i in range(1 << 20):
+                x = 1
+
+        rdd = self.sc.parallelize(range(100))
+        rdd.foreach(heavy_foo)
+

 class ExamplePointUDT(UserDefinedType):
     """
@@ -23,8 +23,6 @@ import sys
 import time
 import socket
 import traceback
-import cProfile
-import pstats

 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry

@@ -90,19 +88,15 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, stats, deserializer, serializer) = command
+        (func, profiler, deserializer, serializer) = command
         init_time = time.time()

         def process():
             iterator = deserializer.load_stream(infile)
             serializer.dump_stream(func(split_index, iterator), outfile)

-        if stats:
-            p = cProfile.Profile()
-            p.runcall(process)
-            st = pstats.Stats(p)
-            st.stream = None  # make it picklable
-            stats.add(st.strip_dirs())
+        if profiler:
+            profiler.profile(process)
         else:
             process()
     except Exception:
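The worker no longer hard-codes cProfile; it asks whatever Profiler was shipped with the serialized command to wrap the partition-processing closure. The same decision in isolation (the names here are illustrative):

    def run_with_optional_profiling(process, profiler=None):
        # Mirrors the refactored worker loop: profile the whole closure when a
        # profiler was shipped with the command, otherwise run it directly.
        if profiler:
            profiler.profile(process)
        else:
            process()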
@@ -57,6 +57,7 @@ function run_core_tests() {
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
     run_test "pyspark/serializers.py"
+    run_test "pyspark/profiler.py"
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }