Fix tracking of updates in accumulators to solve an issue that would manifest in the 2.9 interpreter

This commit is contained in:
Matei Zaharia 2011-07-14 14:08:34 -04:00
parent 0ccfe20755
commit 797b4547c3
2 changed files with 21 additions and 13 deletions

View file

@@ -12,7 +12,7 @@ import scala.collection.mutable.Map
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
Accumulators.register(this)
Accumulators.register(this, true)
def += (term: T) { value_ = param.addInPlace(value_, term) }
def value = this.value_
@@ -26,7 +26,7 @@ import scala.collection.mutable.Map
in.defaultReadObject
value_ = zero
deserialized = true
Accumulators.register(this)
Accumulators.register(this, false)
}
override def toString = value_.toString
@@ -42,31 +42,39 @@ import scala.collection.mutable.Map
private object Accumulators
{
// TODO: Use soft references? => need to make readObject work properly then
val accums = Map[(Thread, Long), Accumulator[_]]()
val originals = Map[Long, Accumulator[_]]()
val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
var lastId: Long = 0
def newId: Long = synchronized { lastId += 1; return lastId }
def register(a: Accumulator[_]): Unit = synchronized {
accums((currentThread, a.id)) = a
def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = a
} else {
val accums = localAccums.getOrElseUpdate(currentThread, Map())
accums(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear: Unit = synchronized {
accums.retain((key, accum) => key._1 != currentThread)
localAccums.remove(currentThread)
}
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for(((thread, id), accum) <- accums if thread == currentThread)
for ((id, accum) <- localAccums.getOrElse(currentThread, Map()))
ret(id) = accum.value
return ret
}
def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (accums.contains((thread, id))) {
val accum = accums((thread, id))
accum.asInstanceOf[Accumulator[Any]] += value
if (originals.contains(id)) {
originals(id).asInstanceOf[Accumulator[Any]] += value
}
}
}

View file

@@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging {
if (evt.reason == Success) {
// A task ended
logInfo("Completed " + evt.task)
Accumulators.add(currentThread, evt.accumUpdates)
Accumulators.add(evt.accumUpdates)
evt.task match {
case rt: ResultTask[_, _] =>
results(rt.outputId) = evt.result.asInstanceOf[U]