[SPARK-2470] PEP8 fixes to PySpark

This pull request aims to resolve all outstanding PEP8 violations in PySpark.
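
Most of the fixes below fall into a handful of PEP 8 rules: identity comparisons with None, implied line continuation inside parentheses instead of backslashes, blank lines around definitions, whitespace around punctuation and keyword arguments, and keeping lines under the 100-character limit referenced in the commits below. For reference only (not part of this patch), a minimal sketch of flagging such violations locally, assuming the third-party pep8 package and an illustrative invocation:

# Sketch only: report PEP 8 violations under python/pyspark/, allowing lines
# of up to 100 characters (the limit referenced in the commit list below).
import pep8  # third-party checker, later renamed pycodestyle

style = pep8.StyleGuide(max_line_length=100)
report = style.check_files(["python/pyspark/"])
print("%d PEP 8 violation(s) found" % report.total_errors)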

Author: Nicholas Chammas <nicholas.chammas@gmail.com>
Author: nchammas <nicholas.chammas@gmail.com>

Closes #1505 from nchammas/master and squashes the following commits:

98171af [Nicholas Chammas] [SPARK-2470] revert PEP 8 fixes to cloudpickle
cba7768 [Nicholas Chammas] [SPARK-2470] wrap expression list in parentheses
e178dbe [Nicholas Chammas] [SPARK-2470] style - change position of line break
9127d2b [Nicholas Chammas] [SPARK-2470] wrap expression lists in parentheses
22132a4 [Nicholas Chammas] [SPARK-2470] wrap conditionals in parentheses
24639bc [Nicholas Chammas] [SPARK-2470] fix whitespace for doctest
7d557b7 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to tests.py
8f8e4c0 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to storagelevel.py
b3b96cf [Nicholas Chammas] [SPARK-2470] PEP8 fixes to statcounter.py
d644477 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to worker.py
aa3a7b6 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to sql.py
1916859 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to shell.py
95d1d95 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to serializers.py
a0fec2e [Nicholas Chammas] [SPARK-2470] PEP8 fixes to mllib
c85e1e5 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to join.py
d14f2f1 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to __init__.py
81fcb20 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to resultiterable.py
1bde265 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to java_gateway.py
7fc849c [Nicholas Chammas] [SPARK-2470] PEP8 fixes to daemon.py
ca2d28b [Nicholas Chammas] [SPARK-2470] PEP8 fixes to context.py
f4e0039 [Nicholas Chammas] [SPARK-2470] PEP8 fixes to conf.py
a6d5e4b [Nicholas Chammas] [SPARK-2470] PEP8 fixes to cloudpickle.py
f0a7ebf [Nicholas Chammas] [SPARK-2470] PEP8 fixes to rddsampler.py
4dd148f [nchammas] Merge pull request #5 from apache/master
f7e4581 [Nicholas Chammas] unrelated pep8 fix
a36eed0 [Nicholas Chammas] name ec2 instances and security groups consistently
de7292a [nchammas] Merge pull request #4 from apache/master
2e4fe00 [nchammas] Merge pull request #3 from apache/master
89fde08 [nchammas] Merge pull request #2 from apache/master
69f6e22 [Nicholas Chammas] PEP8 fixes
2627247 [Nicholas Chammas] broke up lines before they hit 100 chars
6544b7e [Nicholas Chammas] [SPARK-2065] give launched instances names
69da6cf [nchammas] Merge pull request #1 from apache/master
Authored by Nicholas Chammas on 2014-07-21 22:30:53 -07:00; committed by Reynold Xin.
Parent: c3462c6568
Commit: 5d16d5bbfd
18 changed files with 127 additions and 97 deletions

python/pyspark/__init__.py

@ -59,4 +59,5 @@ from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", "SparkFiles", "StorageLevel", "Row"]
__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD",
"SparkFiles", "StorageLevel", "Row"]

python/pyspark/conf.py

@ -50,7 +50,8 @@ spark.executorEnv.VAR3=value3
spark.executorEnv.VAR4=value4
spark.home=/path
>>> sorted(conf.getAll(), key=lambda p: p[0])
[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \
(u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
"""
@ -118,9 +119,9 @@ class SparkConf(object):
"""Set an environment variable to be passed to executors."""
if (key is not None and pairs is not None) or (key is None and pairs is None):
raise Exception("Either pass one key-value pair or a list of pairs")
elif key != None:
elif key is not None:
self._jconf.setExecutorEnv(key, value)
elif pairs != None:
elif pairs is not None:
for (k, v) in pairs:
self._jconf.setExecutorEnv(k, v)
return self
@ -137,7 +138,7 @@ class SparkConf(object):
def get(self, key, defaultValue=None):
"""Get the configured value for some key, or return a default otherwise."""
if defaultValue == None: # Py4J doesn't call the right get() if we pass None
if defaultValue is None: # Py4J doesn't call the right get() if we pass None
if not self._jconf.contains(key):
return None
return self._jconf.get(key)
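
The key != None and defaultValue == None checks above become is not None / is None, which is what PEP 8 (E711) calls for. A standalone illustration, not taken from the patch, of why equality with None is unreliable:

# Illustration only: a custom __eq__ can make '== None' return True even
# though the object is not None; 'is None' always tests identity.
class AlwaysEqual(object):
    def __eq__(self, other):
        return True

obj = AlwaysEqual()
print(obj == None)   # True  -- misleading (and itself an E711 violation)
print(obj is None)   # False -- the question we actually wanted answered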

python/pyspark/context.py

@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer
PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@ -50,12 +50,11 @@ class SparkContext(object):
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
gateway=None):
environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
gateway=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@ -138,8 +137,8 @@ class SparkContext(object):
self._accumulatorServer = accumulators._start_update_server()
(host, port) = self._accumulatorServer.server_address
self._javaAccumulator = self._jsc.accumulator(
self._jvm.java.util.ArrayList(),
self._jvm.PythonAccumulatorParam(host, port))
self._jvm.java.util.ArrayList(),
self._jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
@ -165,7 +164,7 @@ class SparkContext(object):
(dirname, filename) = os.path.split(path)
self._python_includes.append(filename)
sys.path.append(path)
if not dirname in sys.path:
if dirname not in sys.path:
sys.path.append(dirname)
# Create a temporary directory inside spark.local.dir:
@ -192,15 +191,19 @@ class SparkContext(object):
SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
if instance:
if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
if (SparkContext._active_spark_context and
SparkContext._active_spark_context != instance):
currentMaster = SparkContext._active_spark_context.master
currentAppName = SparkContext._active_spark_context.appName
callsite = SparkContext._active_spark_context._callsite
# Raise error if there is already a running Spark context
raise ValueError("Cannot run multiple SparkContexts at once; existing SparkContext(app=%s, master=%s)" \
" created by %s at %s:%s " \
% (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum))
raise ValueError(
"Cannot run multiple SparkContexts at once; "
"existing SparkContext(app=%s, master=%s)"
" created by %s at %s:%s "
% (currentAppName, currentMaster,
callsite.function, callsite.file, callsite.linenum))
else:
SparkContext._active_spark_context = instance
@ -290,7 +293,7 @@ class SparkContext(object):
Read a text file from HDFS, a local file system (available on all
nodes), or any Hadoop-supported file system URI, and return it as an
RDD of Strings.
>>> path = os.path.join(tempdir, "sample-text.txt")
>>> with open(path, "w") as testFile:
... testFile.write("Hello world!")
@ -584,11 +587,12 @@ class SparkContext(object):
HTTP, HTTPS or FTP URI.
"""
self.addFile(path)
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
self._python_includes.append(filename)
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) # for tests in local mode
# for tests in local mode
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
def setCheckpointDir(self, dirName):
"""
@ -649,9 +653,9 @@ class SparkContext(object):
Cancelled
If interruptOnCancel is set to true for the job group, then job cancellation will result
in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure
that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208,
where HDFS may respond to Thread.interrupt() by marking nodes as dead.
in Thread.interrupt() being called on the job's executor threads. This is useful to help
ensure that the tasks are actually stopped in a timely manner, but is off by default due
to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
"""
self._jsc.setJobGroup(groupId, description, interruptOnCancel)
@ -688,7 +692,7 @@ class SparkContext(object):
"""
self._jsc.sc().cancelAllJobs()
def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
Executes the given partitionFunc on the specified set of partitions,
returning the result as an array of elements.
@ -703,7 +707,7 @@ class SparkContext(object):
>>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
[0, 1, 16, 25]
"""
if partitions == None:
if partitions is None:
partitions = range(rdd._jrdd.partitions().size())
javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
@ -714,6 +718,7 @@ class SparkContext(object):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
def _test():
import atexit
import doctest
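
The long raise ValueError(...) above is rebuilt from adjacent string literals inside a single pair of parentheses, and the multi-clause if condition is wrapped in parentheses as well, replacing backslash continuations per the "wrap expression lists in parentheses" commits. A short standalone illustration (hypothetical strings, not from the patch) of the preferred continuation style:

# Illustration only: PEP 8 prefers implied continuation inside (), [] or {}
# over backslash continuation; adjacent string literals are concatenated.

# Discouraged: backslash continuation
message = "Cannot run multiple contexts at once; " \
          "existing context created elsewhere"

# Preferred: one pair of parentheses, the string split across lines
message = ("Cannot run multiple contexts at once; "
           "existing context created elsewhere")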

python/pyspark/daemon.py

@ -42,12 +42,12 @@ def should_exit():
def compute_real_exit_code(exit_code):
# SystemExit's code can be integer or string, but os._exit only accepts integers
import numbers
if isinstance(exit_code, numbers.Integral):
return exit_code
else:
return 1
# SystemExit's code can be integer or string, but os._exit only accepts integers
import numbers
if isinstance(exit_code, numbers.Integral):
return exit_code
else:
return 1
def worker(listen_sock):

python/pyspark/java_gateway.py

@ -24,6 +24,7 @@ from subprocess import Popen, PIPE
from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
def launch_gateway():
SPARK_HOME = os.environ["SPARK_HOME"]

python/pyspark/join.py

@ -33,10 +33,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from pyspark.resultiterable import ResultIterable
def _do_python_join(rdd, other, numPartitions, dispatch):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x : dispatch(x.__iter__()))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))
def python_join(rdd, other, numPartitions):
@ -85,6 +86,7 @@ def python_cogroup(rdds, numPartitions):
vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)
def dispatch(seq):
bufs = [[] for i in range(rdd_len)]
for (n, v) in seq:

python/pyspark/mllib/_common.py

@ -164,7 +164,7 @@ def _deserialize_double_vector(ba, offset=0):
nb = len(ba) - offset
if nb < 5:
raise TypeError("_deserialize_double_vector called on a %d-byte array, "
"which is too short" % nb)
"which is too short" % nb)
if ba[offset] == DENSE_VECTOR_MAGIC:
return _deserialize_dense_vector(ba, offset)
elif ba[offset] == SPARSE_VECTOR_MAGIC:
@ -272,6 +272,7 @@ def _serialize_labeled_point(p):
header_float[0] = p.label
return header + serialized_features
def _deserialize_labeled_point(ba, offset=0):
"""Deserialize a LabeledPoint from a mutually understood format."""
from pyspark.mllib.regression import LabeledPoint
@ -283,6 +284,7 @@ def _deserialize_labeled_point(ba, offset=0):
features = _deserialize_double_vector(ba, offset + 9)
return LabeledPoint(label, features)
def _copyto(array, buffer, offset, shape, dtype):
"""
Copy the contents of a vector to a destination bytearray at the

python/pyspark/mllib/linalg.py

@ -247,6 +247,7 @@ class Vectors(object):
else:
return "[" + ",".join([str(v) for v in vector]) + "]"
def _test():
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)

python/pyspark/mllib/util.py

@ -24,7 +24,6 @@ from pyspark.rdd import RDD
from pyspark.serializers import NoOpSerializer
class MLUtils:
"""
Helper methods to load, save and pre-process data used in MLlib.
@ -154,7 +153,6 @@ class MLUtils:
lines = data.map(lambda p: MLUtils._convert_labeled_point_to_libsvm(p))
lines.saveAsTextFile(dir)
@staticmethod
def loadLabeledPoints(sc, path, minPartitions=None):
"""

python/pyspark/rddsampler.py

@ -18,13 +18,16 @@
import sys
import random
class RDDSampler(object):
def __init__(self, withReplacement, fraction, seed=None):
try:
import numpy
self._use_numpy = True
except ImportError:
print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
print >> sys.stderr, (
"NumPy does not appear to be installed. "
"Falling back to default random generator for sampling.")
self._use_numpy = False
self._seed = seed if seed is not None else random.randint(0, sys.maxint)
@ -61,7 +64,7 @@ class RDDSampler(object):
def getPoissonSample(self, split, mean):
if not self._rand_initialized or split != self._split:
self.initRandomGenerator(split)
if self._use_numpy:
return self._random.poisson(mean)
else:
@ -80,30 +83,27 @@ class RDDSampler(object):
num_arrivals += 1
return (num_arrivals - 1)
def shuffle(self, vals):
if self._random is None:
self.initRandomGenerator(0) # this should only ever called on the master so
# the split does not matter
if self._use_numpy:
self._random.shuffle(vals)
else:
self._random.shuffle(vals, self._random.random)
def func(self, split, iterator):
if self._withReplacement:
if self._withReplacement:
for obj in iterator:
# For large datasets, the expected number of occurrences of each element in a sample with
# replacement is Poisson(frac). We use that to get a count for each element.
count = self.getPoissonSample(split, mean = self._fraction)
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
count = self.getPoissonSample(split, mean=self._fraction)
for _ in range(0, count):
yield obj
else:
for obj in iterator:
if self.getUniformSample(split) <= self._fraction:
yield obj

python/pyspark/resultiterable.py

@ -19,6 +19,7 @@ __all__ = ["ResultIterable"]
import collections
class ResultIterable(collections.Iterable):
"""
A special result iterable. This is used because the standard iterator can not be pickled
@ -27,7 +28,9 @@ class ResultIterable(collections.Iterable):
self.data = data
self.index = 0
self.maxindex = len(data)
def __iter__(self):
return iter(self.data)
def __len__(self):
return len(self.data)
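
Several hunks in this patch, including the one above, only add or normalize blank lines. A short standalone reminder (hypothetical names, not from the patch) of the blank-line rules they satisfy:

# Illustration only: E302 expects two blank lines before top-level
# definitions, E301 one blank line between methods.
import math


def area(radius):  # two blank lines above a top-level definition
    return math.pi * radius ** 2


class Circle(object):

    def __init__(self, radius):
        self.radius = radius

    def area(self):  # one blank line between methods
        return area(self.radius)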

python/pyspark/serializers.py

@ -91,7 +91,6 @@ class Serializer(object):
"""
raise NotImplementedError
def _load_stream_without_unbatching(self, stream):
return self.load_stream(stream)
@ -197,8 +196,8 @@ class BatchedSerializer(Serializer):
return self.serializer.load_stream(stream)
def __eq__(self, other):
return isinstance(other, BatchedSerializer) and \
other.serializer == self.serializer
return (isinstance(other, BatchedSerializer) and
other.serializer == self.serializer)
def __str__(self):
return "BatchedSerializer<%s>" % str(self.serializer)
@ -229,8 +228,8 @@ class CartesianDeserializer(FramedSerializer):
yield pair
def __eq__(self, other):
return isinstance(other, CartesianDeserializer) and \
self.key_ser == other.key_ser and self.val_ser == other.val_ser
return (isinstance(other, CartesianDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __str__(self):
return "CartesianDeserializer<%s, %s>" % \
@ -252,18 +251,20 @@ class PairDeserializer(CartesianDeserializer):
yield pair
def __eq__(self, other):
return isinstance(other, PairDeserializer) and \
self.key_ser == other.key_ser and self.val_ser == other.val_ser
return (isinstance(other, PairDeserializer) and
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __str__(self):
return "PairDeserializer<%s, %s>" % \
(str(self.key_ser), str(self.val_ser))
return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))
class NoOpSerializer(FramedSerializer):
def loads(self, obj): return obj
def dumps(self, obj): return obj
def loads(self, obj):
return obj
def dumps(self, obj):
return obj
class PickleSerializer(FramedSerializer):
@ -276,12 +277,16 @@ class PickleSerializer(FramedSerializer):
not be as fast as more specialized serializers.
"""
def dumps(self, obj): return cPickle.dumps(obj, 2)
def dumps(self, obj):
return cPickle.dumps(obj, 2)
loads = cPickle.loads
class CloudPickleSerializer(PickleSerializer):
def dumps(self, obj): return cloudpickle.dumps(obj, 2)
def dumps(self, obj):
return cloudpickle.dumps(obj, 2)
class MarshalSerializer(FramedSerializer):

python/pyspark/shell.py

@ -35,7 +35,8 @@ from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
# this is the equivalent of ADD_JARS
add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") is not None else None
add_files = (os.environ.get("ADD_FILES").split(',')
if os.environ.get("ADD_FILES") is not None else None)
if os.environ.get("SPARK_EXECUTOR_URI"):
SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])

python/pyspark/sql.py

@ -30,7 +30,7 @@ class SQLContext:
tables, execute SQL over tables, cache tables, and read parquet files.
"""
def __init__(self, sparkContext, sqlContext = None):
def __init__(self, sparkContext, sqlContext=None):
"""Create a new SQLContext.
@param sparkContext: The SparkContext to wrap.
@ -137,7 +137,6 @@ class SQLContext:
jschema_rdd = self._ssql_ctx.parquetFile(path)
return SchemaRDD(jschema_rdd, self)
def jsonFile(self, path):
"""Loads a text file storing one JSON object per line,
returning the result as a L{SchemaRDD}.
@ -234,8 +233,8 @@ class HiveContext(SQLContext):
self._scala_HiveContext = self._get_hive_ctx()
return self._scala_HiveContext
except Py4JError as e:
raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run " \
"sbt/sbt assembly" , e)
raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
"sbt/sbt assembly", e)
def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())
@ -377,7 +376,7 @@ class SchemaRDD(RDD):
"""
self._jschema_rdd.registerAsTable(name)
def insertInto(self, tableName, overwrite = False):
def insertInto(self, tableName, overwrite=False):
"""Inserts the contents of this SchemaRDD into the specified table.
Optionally overwriting any existing data.
@ -420,7 +419,7 @@ class SchemaRDD(RDD):
# in Java land in the javaToPython function. May require a custom
# pickle serializer in Pyrolite
return RDD(jrdd, self._sc, BatchedSerializer(
PickleSerializer())).map(lambda d: Row(d))
PickleSerializer())).map(lambda d: Row(d))
# We override the default cache/persist/checkpoint behavior as we want to cache the underlying
# SchemaRDD object in the JVM, not the PythonRDD checkpointed by the super class
@ -483,6 +482,7 @@ class SchemaRDD(RDD):
else:
raise ValueError("Can only subtract another SchemaRDD")
def _test():
import doctest
from array import array
@ -493,20 +493,25 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
{"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
'{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
globs['rdd'] = sc.parallelize(
[{"field1": 1, "field2": "row1"},
{"field1": 2, "field2": "row2"},
{"field1": 3, "field2": "row3"}]
)
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
'{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}'
]
globs['jsonStrings'] = jsonStrings
globs['json'] = sc.parallelize(jsonStrings)
globs['nestedRdd1'] = sc.parallelize([
{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
{"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
{"f1": array('i', [1, 2]), "f2": {"row1": 1.0}},
{"f1": array('i', [2, 3]), "f2": {"row2": 2.0}}])
globs['nestedRdd2'] = sc.parallelize([
{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
{"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
(failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
{"f1": [[1, 2], [2, 3]], "f2": set([1, 2]), "f3": (1, 2)},
{"f1": [[2, 3], [3, 4]], "f2": set([2, 3]), "f3": (2, 3)}])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
@ -514,4 +519,3 @@ def _test():
if __name__ == "__main__":
_test()
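
The doctest-setup changes above are mostly whitespace: no space before the colons in the dict literals, and a space after each comma, including the missing one in testmod(globs=globs,optionflags=...). A tiny illustration (hypothetical data, not from the patch) of the convention:

# Illustration only: E203 forbids whitespace before ':' or ',',
# while E231 requires whitespace after them.
row = {"field1": 1, "field2": "row1"}        # conforming
# row = {"field1" : 1,"field2" : "row1"}     # E203 before ':', E231 after ','
print(sorted(row.items(), key=lambda p: p[0]))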

python/pyspark/statcounter.py

@ -20,18 +20,19 @@
import copy
import math
class StatCounter(object):
def __init__(self, values=[]):
self.n = 0L # Running count of our values
self.mu = 0.0 # Running mean of our values
self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2)
self.maxValue = float("-inf")
self.minValue = float("inf")
for v in values:
self.merge(v)
# Add a value into this StatCounter, updating the internal statistics.
def merge(self, value):
delta = value - self.mu
@ -42,7 +43,7 @@ class StatCounter(object):
self.maxValue = value
if self.minValue > value:
self.minValue = value
return self
# Merge another StatCounter into this one, adding up the internal statistics.
@ -50,7 +51,7 @@ class StatCounter(object):
if not isinstance(other, StatCounter):
raise Exception("Can only merge Statcounters!")
if other is self: # reference equality holds
if other is self: # reference equality holds
self.merge(copy.deepcopy(other)) # Avoid overwriting fields in a weird order
else:
if self.n == 0:
@ -59,8 +60,8 @@ class StatCounter(object):
self.n = other.n
self.maxValue = other.maxValue
self.minValue = other.minValue
elif other.n != 0:
elif other.n != 0:
delta = other.mu - self.mu
if other.n * 10 < self.n:
self.mu = self.mu + (delta * other.n) / (self.n + other.n)
@ -68,10 +69,10 @@ class StatCounter(object):
self.mu = other.mu - (delta * self.n) / (self.n + other.n)
else:
self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
self.maxValue = max(self.maxValue, other.maxValue)
self.minValue = min(self.minValue, other.minValue)
self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
self.n += other.n
return self
@ -94,7 +95,7 @@ class StatCounter(object):
def max(self):
return self.maxValue
# Return the variance of the values.
def variance(self):
if self.n == 0:
@ -124,5 +125,5 @@ class StatCounter(object):
return math.sqrt(self.sampleVariance())
def __repr__(self):
return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min())
return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
(self.count(), self.mean(), self.stdev(), self.max(), self.min()))

python/pyspark/storagelevel.py

@ -17,6 +17,7 @@
__all__ = ["StorageLevel"]
class StorageLevel:
"""
Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
@ -25,7 +26,7 @@ class StorageLevel:
Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY.
"""
def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication = 1):
def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1):
self.useDisk = useDisk
self.useMemory = useMemory
self.useOffHeap = useOffHeap
@ -55,4 +56,4 @@ StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True)
StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2)
StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False)
StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2)
StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1)
StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1)

python/pyspark/tests.py

@ -52,12 +52,13 @@ class PySparkTestCase(unittest.TestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
self.sc = SparkContext('local[4]', class_name , batchSize=2)
self.sc = SparkContext('local[4]', class_name, batchSize=2)
def tearDown(self):
self.sc.stop()
sys.path = self._old_sys_path
class TestCheckpoint(PySparkTestCase):
def setUp(self):
@ -190,6 +191,7 @@ class TestRDDFunctions(PySparkTestCase):
def testAggregateByKey(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
def seqOp(x, y):
x.add(y)
return x
@ -197,17 +199,19 @@ class TestRDDFunctions(PySparkTestCase):
def combOp(x, y):
x |= y
return x
sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
self.assertEqual(3, len(sets))
self.assertEqual(set([1]), sets[1])
self.assertEqual(set([2]), sets[3])
self.assertEqual(set([1, 3]), sets[5])
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):
import subprocess
def func(x):
subprocess.check_call('ls', shell=True)
self.sc.parallelize([1]).foreach(func)
@ -479,7 +483,7 @@ class TestSparkSubmit(unittest.TestCase):
| return x + 1
""")
proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, script],
stdout=subprocess.PIPE)
stdout=subprocess.PIPE)
out, err = proc.communicate()
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)

python/pyspark/worker.py

@ -57,8 +57,8 @@ def main(infile, outfile):
SparkFiles._is_running_on_worker = True
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))