
229 lines
7 KiB
Raw Normal View History

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2013-01-20 04:57:44 -05:00
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
>>> a.value = 2
>>> a.value
>>> a += 5
>>> a.value
>>> sc.accumulator(1.0).value
>>> sc.accumulator(1j).value
2013-01-20 04:57:44 -05:00
>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
... global a
... a += x
>>> rdd.foreach(f)
>>> a.value
>>> b = sc.accumulator(0)
>>> def g(x):
... b.add(x)
>>> rdd.foreach(g)
>>> b.value
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
2013-01-20 04:57:44 -05:00
... def zero(self, value):
... return [0.0] * len(value)
... def addInPlace(self, val1, val2):
... for i in xrange(len(val1)):
... val1[i] += val2[i]
... return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
... global va
... va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]
>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
>>> def h(x):
... global a
... a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer
2013-01-20 04:57:44 -05:00
# Module-wide serializer; the update server's request handler uses it to read
# length-prefixed (accumulator id, update) pairs from the socket.
pickleSer = PickleSerializer()

# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
def _deserialize_accumulator(aid, zero_value, accum_param):
    """
    Reconstruction hook returned by C{Accumulator.__reduce__}.

    Builds a fresh L{Accumulator} with id C{aid} whose value starts at
    C{zero_value}, flags it as deserialized (which makes its C{value}
    property refuse access), records it in the accumulator registry under
    C{aid}, and returns it.
    """
    from pyspark.accumulators import _accumulatorRegistry
    restored = Accumulator(aid, zero_value, accum_param)
    restored._deserialized = True
    _accumulatorRegistry[aid] = restored
    return restored
class Accumulator(object):
    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        # NOTE(review): the registry is imported by package path instead of being
        # read as a module global -- presumably so the same dict is shared even
        # when this file is executed directly (e.g. for doctests); confirm.
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        # Flipped to True by _deserialize_accumulator for copies rebuilt from a pickle.
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        # The pickled form carries a zero value instead of the current value, so
        # the rebuilt copy starts from zero and only holds what is added to it.
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        # Returning self keeps the name bound to this same accumulator after `+=`.
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
2013-01-20 04:57:44 -05:00
class AccumulatorParam(object):
    """
    Helper object that defines how to accumulate values of a given type.

    This base class only declares the interface; both methods raise
    C{NotImplementedError} and must be overridden by subclasses.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError
class AddingAccumulatorParam(AccumulatorParam):
    """
    An AccumulatorParam that uses the + operators to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        """Remember C{zero_value} as the value that C{zero} will hand back."""
        self.zero_value = zero_value

    def zero(self, value):
        """Return the stored zero value; the C{value} argument is not consulted."""
        return self.zero_value

    def addInPlace(self, value1, value2):
        """
        Return C{value1 += value2}. For types that implement in-place addition
        (such as lists) this mutates C{value1} itself; otherwise a new object
        is returned.
        """
        value1 += value2
        return value1
# Singleton accumulator params for some standard types; each one is an
# AddingAccumulatorParam seeded with that type's additive zero.
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
    """
    Request handler that applies a batch of accumulator updates read from a
    connection and then acknowledges the batch.
    """

    def handle(self):
        """
        Read an update count, then that many C{(aid, update)} pairs, adding each
        update to the registered accumulator with the matching id. A single
        byte is written back once the whole batch has been applied.
        """
        from pyspark.accumulators import _accumulatorRegistry
        remaining = read_int(self.rfile)
        while remaining > 0:
            # NOTE(review): pickleSer presumably unpickles the payload it reads
            # from the socket -- confirm only trusted peers can connect.
            aid, update = pickleSer._read_with_length(self.rfile)
            _accumulatorRegistry[aid] += update
            remaining -= 1
        # Write a byte in acknowledgement
        self.wfile.write(struct.pack("!b", 1))
def _start_update_server():
    """
    Start a TCP server to receive accumulator updates in a daemon thread, and returns it

    The server is bound to C{localhost} on port 0, so the operating system picks
    a free port; callers can read the chosen address from the returned server's
    C{server_address}. Connections are handled by L{_UpdateRequestHandler}.
    """
    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    # Daemon thread: it must not keep the interpreter alive at shutdown.
    thread.daemon = True
    # Without this call the serving loop never runs and no update is ever received.
    thread.start()
    return server