From 823878c77f41cab39b4b06f0168f2ebcadd24dba Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Fri, 17 Aug 2012 15:52:42 -0700
Subject: [PATCH] add accumulators for mutable collections, with correct typing!

---
 core/src/main/scala/spark/Accumulators.scala     | 15 +++++++++-
 core/src/main/scala/spark/SparkContext.scala     |  6 +++-
 .../test/scala/spark/AccumulatorSuite.scala      | 30 +++++++++++++++++--
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bf77417852..cc5bed257b 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,6 +3,7 @@ package spark
 import java.io._

 import scala.collection.mutable.Map
+import collection.generic.Growable

 class Accumulable[T,R] (
   @transient initialValue: T,
@@ -99,6 +100,18 @@ trait AccumulableParam[R,T] extends Serializable {
   def zero(initialValue: R): R
 }

+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] extends AccumulableParam[R,T] {
+  def addAccumulator(growable: R, elem: T) : R = {
+    growable += elem
+    growable
+  }
+  def addInPlace(t1: R, t2: R) : R = {
+    t1 ++= t2
+    t1
+  }
+  def zero(initialValue: R) = initialValue
+}
+
 // TODO: The multi-thread support in accumulators is kind of lame; check
 // if there's a more intuitive way of doing it right
 private object Accumulators {
@@ -143,4 +156,4 @@ private object Accumulators {
       }
     }
   }
-}
+}
\ No newline at end of file
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index e220972e8f..4c45c7be68 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -4,7 +4,6 @@ import java.io._
 import java.util.concurrent.atomic.AtomicInteger

 import scala.actors.remote.RemoteActor
-import scala.collection.mutable.ArrayBuffer

 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
@@ -31,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
 import org.apache.mesos.MesosNativeLibrary

 import spark.broadcast._
+import collection.generic.Growable

 class SparkContext(
   master: String,
@@ -253,6 +253,10 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)

+  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }

   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index a59b77fc85..68230c4b92 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -3,9 +3,6 @@ package spark
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
 import collection.mutable
-import java.util.Random
-import scala.math.exp
-import scala.math.signum
 import spark.SparkContext._

 class AccumulatorSuite extends FunSuite with ShouldMatchers {
@@ -79,4 +76,31 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
     }
   }

+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      // test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      // NOTE that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+
+  }
+
 }
\ No newline at end of file
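
A minimal driver-side sketch of how the accumulableCollection API added above can be used; the master URL, app name, and object name here are illustrative placeholders, not taken from the patch:

import scala.collection.mutable
import spark.SparkContext

object CollectionAccumulatorExample {
  def main(args: Array[String]) {
    // local master with two worker threads and a hypothetical app name
    val sc = new SparkContext("local[2]", "collection-accumulator-example")

    // statically typed as Accumulable[mutable.HashSet[Int], Int]; no casts needed on .value
    val seen = sc.accumulableCollection(mutable.HashSet[Int]())

    // tasks add elements; the driver reads the merged collection afterwards
    sc.parallelize(1 to 100).foreach(x => seen += x)
    println("distinct values accumulated: " + seen.value.size)

    sc.stop()
  }
}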