Add an add() method to pyspark accumulators.

Add a regular method for adding a term to accumulators in
pyspark. Currently, if you have a non-global accumulator, adding to it
is awkward. The += operator can't be used on accumulators captured via
closure because it involves an assignment: accum += x rebinds the name
accum, so Python treats it as a local variable of the enclosing
function. The only way to add to such an accumulator is to call
__iadd__ directly.
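
A minimal sketch of the scoping problem, with no Spark involved
(Counter here is a hypothetical stand-in for an accumulator, used only
to illustrate the closure behavior):

class Counter(object):
    def __init__(self):
        self.value = 0
    def __iadd__(self, term):
        self.value += term
        return self

counter = Counter()

def f(x):
    # += rebinds the name 'counter', so the compiler marks it local
    # to f; the read that += performs then fails before __iadd__ runs.
    counter += x

f(1)  # UnboundLocalError: local variable 'counter' referenced before assignment

Calling counter.__iadd__(x), or a plain method like add(x), performs
no rebinding and therefore works inside the closure.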

Adding this method lets you write code like this:

from pyspark import SparkContext

def main():
    sc = SparkContext()
    accum = sc.accumulator(0)

    rdd = sc.parallelize([1,2,3])
    def f(x):
        accum.add(x)
    rdd.foreach(f)
    print accum.value

where using accum += x instead would cause UnboundLocalError
exceptions in the workers. Without this method, the addition has to be
written as accum.__iadd__(x).
commit 7eaa56de7f
parent 6511bbe2ad
Author: Ewen Cheslack-Postava
Date:   2013-10-19 19:55:39 -07:00

python/pyspark/accumulators.py

@@ -42,6 +42,13 @@
 >>> a.value
 13
 
+>>> b = sc.accumulator(0)
+>>> def g(x):
+...     b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
 >>> from pyspark.accumulators import AccumulatorParam
 >>> class VectorAccumulatorParam(AccumulatorParam):
 ...     def zero(self, value):
@@ -139,9 +146,13 @@ class Accumulator(object):
             raise Exception("Accumulator.value cannot be accessed inside tasks")
         self._value = value
 
+    def add(self, term):
+        """Adds a term to this accumulator's value"""
+        self._value = self.accum_param.addInPlace(self._value, term)
+
     def __iadd__(self, term):
         """The += operator; adds a term to this accumulator's value"""
-        self._value = self.accum_param.addInPlace(self._value, term)
+        self.add(term)
         return self
 
     def __str__(self):
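
Because add() funnels everything through accum_param.addInPlace, the
same method works for custom accumulator types. A sketch of how that
looks with a vector-valued accumulator, in the spirit of the
VectorAccumulatorParam doctest that the first hunk truncates (the
method bodies and the foreach below are illustrative, not part of this
diff):

from pyspark import SparkContext
from pyspark.accumulators import AccumulatorParam

class VectorAccumulatorParam(AccumulatorParam):
    def zero(self, value):
        # A fresh zero vector of the same length as the initial value.
        return [0.0] * len(value)

    def addInPlace(self, val1, val2):
        # Element-wise addition; mutating val1 in place is allowed.
        for i in range(len(val1)):
            val1[i] += val2[i]
        return val1

sc = SparkContext()
vec = sc.accumulator([0.0, 0.0, 0.0], VectorAccumulatorParam())

def h(x):
    # add() delegates to addInPlace; no assignment, so the closure
    # capture is safe on the workers.
    vec.add([x, 10.0 * x, 100.0 * x])

sc.parallelize([1, 2, 3]).foreach(h)
print vec.value  # [6.0, 60.0, 600.0]

Note that __iadd__ now delegates to add(), so both spellings share a
single code path through addInPlace.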