diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bf77417852..9b273ff62f 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -35,7 +35,16 @@ class Accumulable[T,R] (
     else throw new UnsupportedOperationException("Can't use read value in task")
   }
 
-  private[spark] def localValue = value_
+  /**
+   * Get the current value of this accumulator from within a task.
+   *
+   * This is NOT the global value of the accumulator. To get the global value after a
+   * completed operation on the dataset, call `value`.
+   *
+   * The typical use of this method is to directly mutate the local value, e.g., to add
+   * an element to a Set.
+   */
+  def localValue = value_
 
   def value_= (t: T) {
     if (!deserialized) value_ = t
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index a59b77fc85..e3bf0a2d6c 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -79,4 +79,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
     }
   }
 
+  test("localValue readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { // test single- and multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      d.foreach { x =>
+        acc.localValue += x
+      }
+      acc.value should be ((1 to maxI).toSet)
+      sc.stop()
+    }
+  }
+
 }
\ No newline at end of file
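
For context, a minimal sketch of the usage pattern the new scaladoc describes: each task mutates its own per-task copy through `localValue`, and the driver reads the merged global result through `value`. This is not part of the patch; the object name, master URL, and element range are illustrative only, and the type parameters follow `class Accumulable[T,R]` as shown above (T is the accumulated value, R the element type).

    import scala.collection.mutable

    import spark.{Accumulable, SparkContext}

    object LocalValueSketch {
      def main(args: Array[String]) {
        // Illustrative local master; any master URL behaves the same way.
        val sc = new SparkContext("local[4]", "localValue-sketch")

        // T = mutable.Set[Any] (accumulated value), R = Any (element type).
        val seen: Accumulable[mutable.Set[Any], Any] =
          sc.accumulable(new mutable.HashSet[Any]())

        sc.parallelize(1 to 100).foreach { x =>
          // Inside a task: localValue is this task's private, partial set,
          // so mutating it directly is safe and is the intended use.
          seen.localValue += x
        }

        // Back on the driver: value is the merged, global result.
        assert(seen.value == (1 to 100).toSet)
        sc.stop()
      }
    }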