diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index 2c29437f64..b4686dcafb 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -49,7 +49,16 @@ class Accumulable[R, T] (
     else throw new UnsupportedOperationException("Can't read accumulator value in task")
   }

-  private[spark] def localValue = value_
+  /**
+   * Get the current value of this accumulator from within a task.
+   *
+   * This is NOT the global value of the accumulator. To get the global value after a
+   * completed operation on the dataset, call `value`.
+   *
+   * The typical use of this method is to directly mutate the local value, e.g., to add
+   * an element to a Set.
+   */
+  def localValue = value_

   def value_= (r: R) {
     if (!deserialized) value_ = r
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index 71df5941e5..ee14820c6a 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -112,4 +112,20 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
       sc.stop()
     }
   }
+
+  test ("localValue readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { // test single- & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      d.foreach {
+        x => acc.localValue += x
+      }
+      acc.value should be ((1 to maxI).toSet)
+      sc.stop()
+    }
+  }
+
 }
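For reviewers, here is a minimal sketch of the usage pattern the new doc comment describes: a task mutates its task-local value directly through `localValue`, and the driver reads the merged global value through `value` once the action completes. The `LocalValueSketch` object, the `IntSetParam` name, and the `AccumulableParam` method signatures are assumptions modeled on the `SetAccum` helper imported by the test, not part of this change.

```scala
import scala.collection.mutable
import spark.{Accumulable, AccumulableParam, SparkContext}

object LocalValueSketch {
  // Assumed AccumulableParam for sets of Ints, modeled on the SetAccum
  // helper in AccumulatorSuite; signatures may need adjusting.
  implicit object IntSetParam extends AccumulableParam[mutable.Set[Int], Int] {
    def addAccumulator(set: mutable.Set[Int], x: Int): mutable.Set[Int] = {
      set += x
      set
    }
    def addInPlace(s1: mutable.Set[Int], s2: mutable.Set[Int]): mutable.Set[Int] = {
      s1 ++= s2
      s1
    }
    def zero(initial: mutable.Set[Int]): mutable.Set[Int] = new mutable.HashSet[Int]()
  }

  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "localValue sketch")
    val acc: Accumulable[mutable.Set[Int], Int] = sc.accumulable(mutable.Set.empty[Int])

    // Inside a task: mutate the task-local value directly.
    sc.parallelize(1 to 100).foreach { x => acc.localValue += x }

    // On the driver, after the action completes: read the merged global value.
    assert(acc.value == (1 to 100).toSet)
    sc.stop()
  }
}
```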