Add an add() method to pyspark accumulators.
Add a regular method for adding a term to accumulators in pyspark. Currently if you have a non-global accumulator, adding to it is awkward. The += operator can't be used for non-global accumulators captured via closure because it involves an assignment. The only way to do it is using __iadd__ directly. Adding this method lets you write code like this: def main(): sc = SparkContext() accum = sc.accumulator(0) rdd = sc.parallelize([1,2,3]) def f(x): accum.add(x) rdd.foreach(f) print accum.value where using accum += x instead would have caused UnboundLocalError exceptions in workers. Currently it would have to be written as accum.__iadd__(x).
This commit is contained in:
parent
6511bbe2ad
commit
7eaa56de7f
|
@ -42,6 +42,13 @@
|
|||
>>> a.value
|
||||
13
|
||||
|
||||
>>> b = sc.accumulator(0)
|
||||
>>> def g(x):
|
||||
... b.add(x)
|
||||
>>> rdd.foreach(g)
|
||||
>>> b.value
|
||||
6
|
||||
|
||||
>>> from pyspark.accumulators import AccumulatorParam
|
||||
>>> class VectorAccumulatorParam(AccumulatorParam):
|
||||
... def zero(self, value):
|
||||
|
@ -139,9 +146,13 @@ class Accumulator(object):
|
|||
raise Exception("Accumulator.value cannot be accessed inside tasks")
|
||||
self._value = value
|
||||
|
||||
def add(self, term):
    """Adds a term to this accumulator's value"""
    # Fold the new term into the current value via the accumulator's
    # AccumulatorParam, then store the combined result.
    combined = self.accum_param.addInPlace(self._value, term)
    self._value = combined
|
||||
|
||||
def __iadd__(self, term):
|
||||
"""The += operator; adds a term to this accumulator's value"""
|
||||
self._value = self.accum_param.addInPlace(self._value, term)
|
||||
self.add(term)
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
|
|
Loading…
Reference in a new issue