Make Python function/line appear in the UI.

Tor Myklebust 2013-12-28 23:34:16 -05:00
parent d812aeece9
commit fec01664a7


@@ -23,6 +23,7 @@ import operator
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
 __all__ = ["RDD"]

+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)

 class RDD(object):
     """
@@ -401,7 +442,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))

     def _collect_iterator_through_file(self, iterator):
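collect() is the first call wrapped in the new _JavaStackTrace context manager. A minimal usage sketch, assuming the definitions added above are in scope and using a stand-in for the SparkContext (the real _context._jsc is a Py4J proxy to the JVM); it shows why the module-level depth counter matters: only the outermost with-block sets and later clears the Java-side call site, so nested RDD operations never overwrite it.

class _FakeJsc(object):
    def setCallSite(self, site):
        print("setCallSite(%r)" % (site,))

class _FakeContext(object):
    def __init__(self):
        self._jsc = _FakeJsc()

sc = _FakeContext()
with _JavaStackTrace(sc):        # depth 0 -> 1: the call site is set once
    with _JavaStackTrace(sc):    # depth 1 -> 2: no extra setCallSite call
        pass                     # leaving the inner block: depth 2 -> 1
# leaving the outer block: depth 1 -> 0, setCallSite(None) clears it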
@@ -582,13 +624,14 @@ class RDD(object):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]

     def first(self):
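The practical effect, sketched with an assumed live SparkContext named sc in a hypothetical driver script (the file name and line number in the comment are illustrative only): while the job runs, the call site pushed to the JVM, and hence the description shown in the Spark UI, combines the PySpark entry point (take, collect, and so on) with the user's file and line instead of an opaque frame inside rdd.py.

# my_job.py (hypothetical driver script, assumes an existing SparkContext `sc`)
squares = sc.parallelize(range(1000)).map(lambda x: x * x)
top = squares.take(10)
# while this job runs, the reported call site reads roughly like
#     "take at my_job.py:3"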
@@ -765,9 +808,10 @@ class RDD(object):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if