make accumulator.localValue public, add tests
This commit is contained in:
parent
680df96c43
commit
206a3833ce
|
@ -35,7 +35,16 @@ class Accumulable[T,R] (
|
||||||
else throw new UnsupportedOperationException("Can't use read value in task")
|
else throw new UnsupportedOperationException("Can't use read value in task")
|
||||||
}
|
}
|
||||||
|
|
||||||
private[spark] def localValue = value_
|
/**
|
||||||
|
* get the current value of this accumulator from within a task.
|
||||||
|
*
|
||||||
|
* This is NOT the global value of the accumulator. To get the global value after a
|
||||||
|
* completed operation on the dataset, call `value`.
|
||||||
|
*
|
||||||
|
* The typical use of this method is to directly mutate the local value, eg., to add
|
||||||
|
* an element to a Set.
|
||||||
|
*/
|
||||||
|
def localValue = value_
|
||||||
|
|
||||||
def value_= (t: T) {
|
def value_= (t: T) {
|
||||||
if (!deserialized) value_ = t
|
if (!deserialized) value_ = t
|
||||||
|
|
|
@ -79,4 +79,18 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test ("localValue readable in tasks") {
|
||||||
|
import SetAccum._
|
||||||
|
val maxI = 1000
|
||||||
|
for (nThreads <- List(1, 10)) { //test single & multi-threaded
|
||||||
|
val sc = new SparkContext("local[" + nThreads + "]", "test")
|
||||||
|
val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
|
||||||
|
val d = sc.parallelize(1 to maxI)
|
||||||
|
d.foreach {
|
||||||
|
x => acc.localValue += x
|
||||||
|
}
|
||||||
|
acc.value should be ( (1 to maxI).toSet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in a new issue