Fix tracking of updates in accumulators to solve an issue that would manifest in the 2.9 interpreter
This commit is contained in:
parent
0ccfe20755
commit
797b4547c3
|
@ -12,7 +12,7 @@ import scala.collection.mutable.Map
|
|||
val zero = param.zero(initialValue) // Zero value to be passed to workers
|
||||
var deserialized = false
|
||||
|
||||
Accumulators.register(this)
|
||||
Accumulators.register(this, true)
|
||||
|
||||
def += (term: T) { value_ = param.addInPlace(value_, term) }
|
||||
def value = this.value_
|
||||
|
@ -26,7 +26,7 @@ import scala.collection.mutable.Map
|
|||
in.defaultReadObject
|
||||
value_ = zero
|
||||
deserialized = true
|
||||
Accumulators.register(this)
|
||||
Accumulators.register(this, false)
|
||||
}
|
||||
|
||||
override def toString = value_.toString
|
||||
|
@ -42,31 +42,39 @@ import scala.collection.mutable.Map
|
|||
private object Accumulators
|
||||
{
|
||||
// TODO: Use soft references? => need to make readObject work properly then
|
||||
val accums = Map[(Thread, Long), Accumulator[_]]()
|
||||
val originals = Map[Long, Accumulator[_]]()
|
||||
val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
|
||||
var lastId: Long = 0
|
||||
|
||||
def newId: Long = synchronized { lastId += 1; return lastId }
|
||||
|
||||
def register(a: Accumulator[_]): Unit = synchronized {
|
||||
accums((currentThread, a.id)) = a
|
||||
def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
|
||||
if (original) {
|
||||
originals(a.id) = a
|
||||
} else {
|
||||
val accums = localAccums.getOrElseUpdate(currentThread, Map())
|
||||
accums(a.id) = a
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the local (non-original) accumulators for the current thread
|
||||
def clear: Unit = synchronized {
|
||||
accums.retain((key, accum) => key._1 != currentThread)
|
||||
localAccums.remove(currentThread)
|
||||
}
|
||||
|
||||
// Get the values of the local accumulators for the current thread (by ID)
|
||||
def values: Map[Long, Any] = synchronized {
|
||||
val ret = Map[Long, Any]()
|
||||
for(((thread, id), accum) <- accums if thread == currentThread)
|
||||
for ((id, accum) <- localAccums.getOrElse(currentThread, Map()))
|
||||
ret(id) = accum.value
|
||||
return ret
|
||||
}
|
||||
|
||||
def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
|
||||
// Add values to the original accumulators with some given IDs
|
||||
def add(values: Map[Long, Any]): Unit = synchronized {
|
||||
for ((id, value) <- values) {
|
||||
if (accums.contains((thread, id))) {
|
||||
val accum = accums((thread, id))
|
||||
accum.asInstanceOf[Accumulator[Any]] += value
|
||||
if (originals.contains(id)) {
|
||||
originals(id).asInstanceOf[Accumulator[Any]] += value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging {
|
|||
if (evt.reason == Success) {
|
||||
// A task ended
|
||||
logInfo("Completed " + evt.task)
|
||||
Accumulators.add(currentThread, evt.accumUpdates)
|
||||
Accumulators.add(evt.accumUpdates)
|
||||
evt.task match {
|
||||
case rt: ResultTask[_, _] =>
|
||||
results(rt.outputId) = evt.result.asInstanceOf[U]
|
||||
|
|
Loading…
Reference in a new issue