add accumulators for mutable collections, with correct typing!

Imran Rashid 2012-08-17 15:52:42 -07:00
parent 680df96c43
commit 823878c77f
3 changed files with 46 additions and 5 deletions
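
The new API in one self-contained sketch (a hypothetical driver program, not part of the commit; the object name, master string, and sample data are illustrative, while accumlableCollection, +=, and .value come from the changes below):

// Hypothetical example: any mutable collection that is Growable and TraversableOnce
// can now back an accumulator, and the result comes back with its real type.
import scala.collection.mutable
import spark.SparkContext

object CollectionAccumulatorExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "example")   // master string chosen for illustration
    val seen = sc.accumlableCollection(mutable.HashSet[String]())
    sc.parallelize(Seq("a", "b", "c")).foreach { s => seen += s }
    println(seen.value)                                // a mutable.HashSet[String], no casts needed
    sc.stop()
  }
}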

Accumulators.scala

@@ -3,6 +3,7 @@ package spark
import java.io._
import scala.collection.mutable.Map
import collection.generic.Growable

class Accumulable[T,R] (
  @transient initialValue: T,
@@ -99,6 +100,18 @@ trait AccumulableParam[R,T] extends Serializable {
  def zero(initialValue: R): R
}

class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] extends AccumulableParam[R,T] {
  def addAccumulator(growable: R, elem: T): R = {
    growable += elem
    growable
  }

  def addInPlace(t1: R, t2: R): R = {
    t1 ++= t2
    t1
  }

  def zero(initialValue: R) = initialValue
}
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators {

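GrowableAccumulableParam above lets any mutable collection that mixes in Growable and TraversableOnce act as an accumulator. For comparison, a hand-written AccumulableParam for one concrete collection type would look roughly like the sketch below (hypothetical, not part of the commit); the generic class removes the need to write one of these per collection:

// Hypothetical hand-written param for mutable.HashSet[Int] only.
import scala.collection.mutable
import spark.AccumulableParam

object IntSetParam extends AccumulableParam[mutable.HashSet[Int], Int] {
  // add a single element produced on a worker
  def addAccumulator(set: mutable.HashSet[Int], elem: Int): mutable.HashSet[Int] = {
    set += elem
    set
  }
  // merge two partial results
  def addInPlace(s1: mutable.HashSet[Int], s2: mutable.HashSet[Int]): mutable.HashSet[Int] = {
    s1 ++= s2
    s1
  }
  def zero(initialValue: mutable.HashSet[Int]): mutable.HashSet[Int] = initialValue
}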
SparkContext.scala

@@ -4,7 +4,6 @@ import java.io._
import java.util.concurrent.atomic.AtomicInteger
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -31,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.mesos.MesosNativeLibrary
import spark.broadcast._
import collection.generic.Growable

class SparkContext(
    master: String,
@@ -253,6 +253,10 @@ class SparkContext(
  def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
    new Accumulable(initialValue, param)

  def accumlableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
    val param = new GrowableAccumulableParam[R,T]
    new Accumulable(initialValue, param)
  }

  // Keep around a weak hash map of values to Cached versions?
  def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)

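The existing accumulable method still needs an AccumulableParam supplied, typically as an implicit, while the new accumlableCollection constructs a GrowableAccumulableParam internally. A rough sketch of the two call styles (the object, value names, and master string are illustrative, and it assumes GrowableAccumulableParam is reachable from user code):

import scala.collection.mutable
import spark.{GrowableAccumulableParam, SparkContext}

object AccumulableEntryPoints {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "example")

    // New helper: the GrowableAccumulableParam is built for you.
    val viaHelper = sc.accumlableCollection(mutable.HashSet[Int]())

    // Generic method: an AccumulableParam must be in scope, here as an implicit val.
    implicit val setParam = new GrowableAccumulableParam[mutable.HashSet[Int], Int]
    val viaImplicit = sc.accumulable(mutable.HashSet[Int]())

    sc.stop()
  }
}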
AccumulatorSuite.scala

@@ -3,9 +3,6 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import collection.mutable
import java.util.Random
import scala.math.exp
import scala.math.signum
import spark.SparkContext._
class AccumulatorSuite extends FunSuite with ShouldMatchers {
@@ -79,4 +76,31 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
}
}
  test ("collection accumulators") {
    val maxI = 1000
    for (nThreads <- List(1, 10)) {
      // test single & multi-threaded
      val sc = new SparkContext("local[" + nThreads + "]", "test")
      val setAcc = sc.accumlableCollection(mutable.HashSet[Int]())
      val bufferAcc = sc.accumlableCollection(mutable.ArrayBuffer[Int]())
      val mapAcc = sc.accumlableCollection(mutable.HashMap[Int,String]())
      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
      d.foreach {
        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
      }

      // NOTE that this is typed correctly -- no casts necessary
      setAcc.value.size should be (maxI)
      bufferAcc.value.size should be (2 * maxI)
      mapAcc.value.size should be (maxI)
      for (i <- 1 to maxI) {
        setAcc.value should contain(i)
        bufferAcc.value should contain(i)
        mapAcc.value should contain (i -> i.toString)
      }
      sc.stop()
    }
  }
}