Fix Accumulators in Java, and add a test for them
This commit is contained in:
parent 3f74f729a1
commit ecf9c08901
@@ -38,14 +38,28 @@ class Accumulable[R, T] (
    */
   def += (term: T) { value_ = param.addAccumulator(value_, term) }
 
+  /**
+   * Add more data to this accumulator / accumulable
+   * @param term the data to add
+   */
+  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
   /**
    * Merge two accumulable objects together
    *
    * Normally, a user will not want to use this version, but will instead call `+=`.
-   * @param term the other Accumulable that will get merged with this
+   * @param term the other `R` that will get merged with this
    */
   def ++= (term: R) { value_ = param.addInPlace(value_, term)}
 
+  /**
+   * Merge two accumulable objects together
+   *
+   * Normally, a user will not want to use this version, but will instead call `add`.
+   * @param term the other `R` that will get merged with this
+   */
+  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
   /**
    * Access the accumulator's current value; only allowed on master.
    */
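A note on why this hunk adds named twins for the symbolic operators: Scala compiles `+=` and `++=` to the mangled JVM names `$plus$eq` and `$plus$plus$eq`, so they are effectively unusable from Java; `add` and `merge` give Java callers ordinary method names for the same operations. A minimal usage sketch follows (the local master, app name, and sample data are illustrative assumptions, not part of the commit):

import java.util.Arrays;

import spark.Accumulator;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.VoidFunction;

public class AccumulatorExample {
  public static void main(String[] args) {
    // Assumed driver setup for illustration only.
    JavaSparkContext sc = new JavaSparkContext("local", "AccumulatorExample");
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3));

    final Accumulator<Integer> sum = sc.intAccumulator(0);
    rdd.foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        sum.add(x);  // the named method this commit introduces; `+=` is not callable from Java
      }
    });
    System.out.println(sum.value());  // 6; value() is readable on the master only
  }
}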
@@ -382,11 +382,12 @@ class SparkContext(
     new Accumulator(initialValue, param)
 
   /**
-   * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+   * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+   * Only the master can access the accumuable's `value`.
    * @tparam T accumulator type
    * @tparam R type that can be added to the accumulator
    */
-  def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+  def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
     new Accumulable(initialValue, param)
 
   /**
@@ -404,7 +405,7 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each cluster only once.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
 
   /**
    * Add a file to be downloaded into the working directory of this Spark job on every node.
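For the broadcast method touched above (a whitespace-only change), a small usage sketch may still help; it assumes the Java wrapper JavaSparkContext.broadcast, which forwards to this Scala method, and illustrative names and data throughout:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.Function;
import spark.broadcast.Broadcast;

public class BroadcastExample {
  public static void main(String[] args) {
    // Assumed driver setup for illustration only.
    JavaSparkContext sc = new JavaSparkContext("local", "BroadcastExample");

    Map<String, Integer> table = new HashMap<String, Integer>();
    table.put("a", 1);
    table.put("b", 2);

    // The table is shipped to each node once, not serialized with every task.
    final Broadcast<Map<String, Integer>> lookup = sc.broadcast(table);

    JavaRDD<Integer> resolved = sc.parallelize(Arrays.asList("a", "b", "a"))
        .map(new Function<String, Integer>() {
          public Integer call(String key) {
            return lookup.value().get(key);  // read the broadcast value inside a task
          }
        });
    System.out.println(resolved.collect());  // [1, 2, 1]
  }
}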
@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
 import spark.SparkContext.IntAccumulatorParam
 import spark.SparkContext.DoubleAccumulatorParam
 import spark.broadcast.Broadcast
@@ -265,25 +265,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
 
   /**
    * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def intAccumulator(initialValue: Int): Accumulator[Int] =
-    sc.accumulator(initialValue)(IntAccumulatorParam)
+  def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+    sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
 
   /**
    * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def doubleAccumulator(initialValue: Double): Accumulator[Double] =
-    sc.accumulator(initialValue)(DoubleAccumulatorParam)
+  def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+    sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
 
   /**
    * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
   def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
     sc.accumulator(initialValue)(accumulatorParam)
 
+  /**
+   * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+   * "add" values with `add`. Only the master can access the accumuable's `value`.
+   */
+  def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+    sc.accumulable(initialValue)(param)
+
   /**
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each cluster only once.
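The new accumulable wrapper is the one piece this commit's test below does not exercise, so here is a hedged sketch of how its two type parameters interact: T is the accumulated type and R the element type that tasks add (per the SparkContext scaladoc above). All names and data are illustrative, not from the commit:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import spark.Accumulable;
import spark.AccumulableParam;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.VoidFunction;

public class AccumulableExample {
  public static void main(String[] args) {
    // Assumed driver setup for illustration only.
    JavaSparkContext sc = new JavaSparkContext("local", "AccumulableExample");
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3));

    // T = List<Integer> (the accumulated collection), R = Integer (what tasks add).
    AccumulableParam<List<Integer>, Integer> listParam =
        new AccumulableParam<List<Integer>, Integer>() {
          public List<Integer> addAccumulator(List<Integer> acc, Integer elem) {
            acc.add(elem);          // fold one R into the running T
            return acc;
          }
          public List<Integer> addInPlace(List<Integer> a, List<Integer> b) {
            a.addAll(b);            // merge two partial Ts from different tasks
            return a;
          }
          public List<Integer> zero(List<Integer> initial) {
            return new ArrayList<Integer>();  // identity value for each task
          }
        };

    final Accumulable<List<Integer>, Integer> collected =
        sc.accumulable(new ArrayList<Integer>(), listParam);
    rdd.foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        collected.add(x);  // the named method; `+=` is not callable from Java
      }
    });
    System.out.println(collected.value());  // e.g. [1, 2, 3]; master only, order not guaranteed
  }
}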
@@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable {
     JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
     zipped.count();
   }
+
+  @Test
+  public void accumulators() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+    final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        intAccum.add(x);
+      }
+    });
+    Assert.assertEquals((Integer) 25, intAccum.value());
+
+    final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        doubleAccum.add((double) x);
+      }
+    });
+    Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+    // Try a custom accumulator type
+    AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+      public Float addInPlace(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float addAccumulator(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float zero(Float initialValue) {
+        return 0.0f;
+      }
+    };
+
+    final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        floatAccum.add((float) x);
+      }
+    });
+    Assert.assertEquals((Float) 25.0f, floatAccum.value());
+  }
 }
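A note on the arithmetic and the casts this test relies on: each assertion expects 25 because the accumulator starts at 10 and the tasks add 1 + 2 + 3 + 4 + 5 = 15. The asInstanceOf casts introduced in JavaSparkContext are what let the Accumulator<Integer> and Accumulator<Double> declarations here compile; they are safe at runtime because JVM generics are erased, so an Accumulator[Int] already stores its value as a boxed java.lang.Integer and only the compile-time (Java-visible) type changes.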