Fix Accumulators in Java, and add a test for them

Matei Zaharia 2013-01-05 20:54:08 -05:00
parent 3f74f729a1
commit ecf9c08901
4 changed files with 79 additions and 13 deletions
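The problem being fixed: the Java API's accumulator factories returned `Accumulator[Int]` and `Accumulator[Double]` over Scala's value types, which Java code cannot use as `Accumulator<Integer>` / `Accumulator<Double>`, and the only mutation method was the symbolic `+=`, which cannot be called by that name from Java. The changes below add plain-named `add`/`merge` aliases and box the return types. A minimal sketch of what now works from Java (assumptions: a local-mode context; the class name and `main` scaffolding are illustrative, not part of the commit):

    import java.util.Arrays;
    import spark.Accumulator;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.VoidFunction;

    public class AccumulatorFromJava {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "AccumulatorFromJava");
        // Now typed as Accumulator<java.lang.Integer>, so Java can hold a reference to it.
        final Accumulator<Integer> sum = sc.intAccumulator(0);
        JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
        rdd.foreach(new VoidFunction<Integer>() {
          public void call(Integer x) {
            sum.add(x);  // add() is the new Java-friendly alias for +=
          }
        });
        System.out.println(sum.value());  // 10; value() may only be read on the master
        sc.stop();
      }
    }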

Accumulators.scala

@@ -38,14 +38,28 @@ class Accumulable[R, T] (
    */
   def += (term: T) { value_ = param.addAccumulator(value_, term) }
 
+  /**
+   * Add more data to this accumulator / accumulable
+   * @param term the data to add
+   */
+  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
   /**
    * Merge two accumulable objects together
    *
    * Normally, a user will not want to use this version, but will instead call `+=`.
-   * @param term the other Accumulable that will get merged with this
+   * @param term the other `R` that will get merged with this
    */
   def ++= (term: R) { value_ = param.addInPlace(value_, term)}
 
+  /**
+   * Merge two accumulable objects together
+   *
+   * Normally, a user will not want to use this version, but will instead call `add`.
+   * @param term the other `R` that will get merged with this
+   */
+  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
   /**
    * Access the accumulator's current value; only allowed on master.
    */
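The motivation for these aliases: Scala compiles `+=` to a method named `$plus$eq` (and `++=` to `$plus$plus$eq`), so Java callers had no readable way to invoke them. A small sketch of the difference, assuming a local-mode context (`AliasDemo` and its variable names are illustrative):

    import spark.Accumulator;
    import spark.api.java.JavaSparkContext;

    public class AliasDemo {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "AliasDemo");
        Accumulator<Integer> sum = sc.intAccumulator(0);
        sum.add(3);        // the new alias added above
        sum.$plus$eq(4);   // what the symbolic += looks like from the Java side
        System.out.println(sum.value());  // 7
        sc.stop();
      }
    }

`merge` plays the same role for `++=`: it combines a whole accumulated value of type `R` rather than adding a single element of type `T`.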

SparkContext.scala

@@ -382,11 +382,12 @@ class SparkContext(
     new Accumulator(initialValue, param)
 
   /**
-   * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+   * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+   * Only the master can access the accumulable's `value`.
    * @tparam T accumulator type
    * @tparam R type that can be added to the accumulator
    */
-  def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+  def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
     new Accumulable(initialValue, param)
 
   /**
@@ -404,7 +405,7 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each node only once.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
 
   /**
    * Add a file to be downloaded into the working directory of this Spark job on every node.
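The `broadcast` change above is whitespace only, but since the new test below exercises only accumulators, here is a usage sketch of that method from the Java side (assumptions: a local-mode context; class and variable names are illustrative):

    import java.util.Arrays;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.Function;
    import spark.broadcast.Broadcast;

    public class BroadcastDemo {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "BroadcastDemo");
        // The array is shipped to each node once, not once per task.
        final Broadcast<int[]> lookup = sc.broadcast(new int[] {10, 20, 30});
        JavaRDD<Integer> indices = sc.parallelize(Arrays.asList(0, 1, 2));
        JavaRDD<Integer> values = indices.map(new Function<Integer, Integer>() {
          public Integer call(Integer i) {
            return lookup.value()[i];  // tasks read, never write, the broadcast value
          }
        });
        System.out.println(values.collect());  // [10, 20, 30]
        sc.stop();
      }
    }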

JavaSparkContext.scala

@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
 import spark.SparkContext.IntAccumulatorParam
 import spark.SparkContext.DoubleAccumulatorParam
 import spark.broadcast.Broadcast
@@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
   /**
    * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def intAccumulator(initialValue: Int): Accumulator[Int] =
-    sc.accumulator(initialValue)(IntAccumulatorParam)
+  def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+    sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
 
   /**
    * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def doubleAccumulator(initialValue: Double): Accumulator[Double] =
-    sc.accumulator(initialValue)(DoubleAccumulatorParam)
+  def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+    sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
 
   /**
    * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
   def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
     sc.accumulator(initialValue)(accumulatorParam)
 
+  /**
+   * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+   * "add" values with `add`. Only the master can access the accumulable's `value`.
+   */
+  def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+    sc.accumulable(initialValue)(param)
+
   /**
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each node only once.
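The `asInstanceOf` casts are safe because generics are erased at runtime and the boxed forms of Scala's `Int` and `Double` are exactly `java.lang.Integer` and `java.lang.Double`; the cast only changes the static type Java sees. The new `accumulable` method is the one addition the test below does not cover, so here is a sketch of a custom `AccumulableParam` that collects elements into a list (`ListParam` and all names are illustrative; the three methods follow the `spark.AccumulableParam` trait, mirroring the `AccumulatorParam` shown in the test):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    import spark.Accumulable;
    import spark.AccumulableParam;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.VoidFunction;

    public class AccumulableDemo {
      // Accumulates Integer elements into a List<Integer>.
      static class ListParam implements AccumulableParam<List<Integer>, Integer> {
        public List<Integer> addAccumulator(List<Integer> acc, Integer x) {
          acc.add(x);    // backs Accumulable.add (and Scala's +=)
          return acc;
        }
        public List<Integer> addInPlace(List<Integer> a, List<Integer> b) {
          a.addAll(b);   // backs Accumulable.merge (and Scala's ++=)
          return a;
        }
        public List<Integer> zero(List<Integer> initial) {
          return new ArrayList<Integer>();
        }
      }

      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "AccumulableDemo");
        final Accumulable<List<Integer>, Integer> acc =
            sc.accumulable((List<Integer>) new ArrayList<Integer>(), new ListParam());
        sc.parallelize(Arrays.asList(1, 2, 3)).foreach(new VoidFunction<Integer>() {
          public void call(Integer x) {
            acc.add(x);
          }
        });
        System.out.println(acc.value());  // [1, 2, 3]
        sc.stop();
      }
    }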

JavaAPISuite.java

@@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable {
     JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
     zipped.count();
   }
+
+  @Test
+  public void accumulators() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+    final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        intAccum.add(x);
+      }
+    });
+    Assert.assertEquals((Integer) 25, intAccum.value());
+
+    final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        doubleAccum.add((double) x);
+      }
+    });
+    Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+    // Try a custom accumulator type
+    AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+      public Float addInPlace(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float addAccumulator(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float zero(Float initialValue) {
+        return 0.0f;
+      }
+    };
+
+    final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        floatAccum.add((float) x);
+      }
+    });
+    Assert.assertEquals((Float) 25.0f, floatAccum.value());
+  }
 }
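Each assertion expects 25 because every accumulator starts at its initial value of 10 and the tasks add the RDD's elements, 1 + 2 + 3 + 4 + 5 = 15, on top of it.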