add accumulators for mutable collections, with correct typing!
commit 823878c77f (parent 680df96c43)
core/src/main/scala/spark/Accumulators.scala

@@ -3,6 +3,7 @@ package spark
 import java.io._
 
 import scala.collection.mutable.Map
+import collection.generic.Growable
 
 class Accumulable[T,R] (
     @transient initialValue: T,
@@ -99,6 +100,18 @@ trait AccumulableParam[R,T] extends Serializable {
   def zero(initialValue: R): R
 }
 
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] extends AccumulableParam[R,T] {
+  def addAccumulator(growable: R, elem: T) : R = {
+    growable += elem
+    growable
+  }
+  def addInPlace(t1: R, t2: R) : R = {
+    t1 ++= t2
+    t1
+  }
+  def zero(initialValue: R) = initialValue
+}
+
 // TODO: The multi-thread support in accumulators is kind of lame; check
 // if there's a more intuitive way of doing it right
 private object Accumulators {
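The view bound in GrowableAccumulableParam is the heart of the change: any R that can be viewed as Growable[T] supports +=, so one param class covers HashSet, ArrayBuffer, HashMap, and every other mutable collection, while the result keeps its concrete type. A minimal sketch of the same trick outside Spark, assuming only the standard library (Scala 2.9-era syntax to match the diff; <% is deprecated in later Scala versions):

import scala.collection.generic.Growable
import scala.collection.mutable

object GrowableDemo {
  // Any R viewable as Growable[T] accepts elements via +=;
  // the result stays statically typed as R, with no casts.
  def drain[R <% Growable[T], T](acc: R, elems: TraversableOnce[T]): R = {
    elems.foreach(acc += _)  // += comes from the Growable view
    acc
  }

  def main(args: Array[String]) {
    val set = drain(mutable.HashSet[Int](), 1 to 5)      // inferred: mutable.HashSet[Int]
    val buf = drain(mutable.ArrayBuffer[Int](), 1 to 5)  // inferred: mutable.ArrayBuffer[Int]
    println(set.size + " " + buf.size)                   // prints "5 5"
  }
}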
core/src/main/scala/spark/SparkContext.scala

@@ -4,7 +4,6 @@ import java.io._
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.actors.remote.RemoteActor
-import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration

@@ -31,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
 import org.apache.mesos.MesosNativeLibrary
 
 import spark.broadcast._
+import collection.generic.Growable
 
 class SparkContext(
   master: String,
@@ -253,6 +253,10 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)
 
+  def accumlableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal)
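A hedged usage sketch of the new SparkContext helper (not part of the commit; the master string and names are illustrative, mirroring the test suite below). The payoff is the last line before stop(): value comes back as the concrete collection type, with no cast:

import scala.collection.mutable

val sc = new SparkContext("local", "demo")
val names = sc.accumlableCollection(mutable.ArrayBuffer[String]())
sc.parallelize(Seq("a", "b", "c")).foreach(x => names += x)
val buf: mutable.ArrayBuffer[String] = names.value  // statically typed, no cast
sc.stop()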
core/src/test/scala/spark/AccumulatorSuite.scala

@@ -3,9 +3,6 @@ package spark
 import org.scalatest.FunSuite
 import org.scalatest.matchers.ShouldMatchers
 import collection.mutable
-import java.util.Random
-import scala.math.exp
-import scala.math.signum
 import spark.SparkContext._
 
 class AccumulatorSuite extends FunSuite with ShouldMatchers {

@@ -79,4 +76,31 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
     }
   }
 
+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      //test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumlableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumlableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumlableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize( (1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      //NOTE that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+
+  }
+
 }
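For contrast, a sketch of what typed collection accumulation required before this commit: one hand-written AccumulableParam per collection type (IntSetParam is an illustrative name, not from the commit). This is exactly the boilerplate GrowableAccumulableParam eliminates:

import scala.collection.mutable

object IntSetParam extends AccumulableParam[mutable.HashSet[Int], Int] {
  def addAccumulator(set: mutable.HashSet[Int], x: Int) = { set += x; set }
  def addInPlace(s1: mutable.HashSet[Int], s2: mutable.HashSet[Int]) = { s1 ++= s2; s1 }
  def zero(initial: mutable.HashSet[Int]) = initial
}

// Usage, given a SparkContext sc:
//   val setAcc = sc.accumulable(mutable.HashSet[Int]())(IntSetParam)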