Fix tracking of updates in accumulators to solve an issue that would manifest in the 2.9 interpreter
This commit is contained in:
parent
0ccfe20755
commit
797b4547c3
|
@ -12,7 +12,7 @@ import scala.collection.mutable.Map
|
||||||
val zero = param.zero(initialValue) // Zero value to be passed to workers
|
val zero = param.zero(initialValue) // Zero value to be passed to workers
|
||||||
var deserialized = false
|
var deserialized = false
|
||||||
|
|
||||||
Accumulators.register(this)
|
Accumulators.register(this, true)
|
||||||
|
|
||||||
def += (term: T) { value_ = param.addInPlace(value_, term) }
|
def += (term: T) { value_ = param.addInPlace(value_, term) }
|
||||||
def value = this.value_
|
def value = this.value_
|
||||||
|
@ -26,7 +26,7 @@ import scala.collection.mutable.Map
|
||||||
in.defaultReadObject
|
in.defaultReadObject
|
||||||
value_ = zero
|
value_ = zero
|
||||||
deserialized = true
|
deserialized = true
|
||||||
Accumulators.register(this)
|
Accumulators.register(this, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString = value_.toString
|
override def toString = value_.toString
|
||||||
|
@ -42,31 +42,39 @@ import scala.collection.mutable.Map
|
||||||
private object Accumulators
|
private object Accumulators
|
||||||
{
|
{
|
||||||
// TODO: Use soft references? => need to make readObject work properly then
|
// TODO: Use soft references? => need to make readObject work properly then
|
||||||
val accums = Map[(Thread, Long), Accumulator[_]]()
|
val originals = Map[Long, Accumulator[_]]()
|
||||||
|
val localAccums = Map[Thread, Map[Long, Accumulator[_]]]()
|
||||||
var lastId: Long = 0
|
var lastId: Long = 0
|
||||||
|
|
||||||
def newId: Long = synchronized { lastId += 1; return lastId }
|
def newId: Long = synchronized { lastId += 1; return lastId }
|
||||||
|
|
||||||
def register(a: Accumulator[_]): Unit = synchronized {
|
def register(a: Accumulator[_], original: Boolean): Unit = synchronized {
|
||||||
accums((currentThread, a.id)) = a
|
if (original) {
|
||||||
|
originals(a.id) = a
|
||||||
|
} else {
|
||||||
|
val accums = localAccums.getOrElseUpdate(currentThread, Map())
|
||||||
|
accums(a.id) = a
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear the local (non-original) accumulators for the current thread
|
||||||
def clear: Unit = synchronized {
|
def clear: Unit = synchronized {
|
||||||
accums.retain((key, accum) => key._1 != currentThread)
|
localAccums.remove(currentThread)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the values of the local accumulators for the current thread (by ID)
|
||||||
def values: Map[Long, Any] = synchronized {
|
def values: Map[Long, Any] = synchronized {
|
||||||
val ret = Map[Long, Any]()
|
val ret = Map[Long, Any]()
|
||||||
for(((thread, id), accum) <- accums if thread == currentThread)
|
for ((id, accum) <- localAccums.getOrElse(currentThread, Map()))
|
||||||
ret(id) = accum.value
|
ret(id) = accum.value
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
def add(thread: Thread, values: Map[Long, Any]): Unit = synchronized {
|
// Add values to the original accumulators with some given IDs
|
||||||
|
def add(values: Map[Long, Any]): Unit = synchronized {
|
||||||
for ((id, value) <- values) {
|
for ((id, value) <- values) {
|
||||||
if (accums.contains((thread, id))) {
|
if (originals.contains(id)) {
|
||||||
val accum = accums((thread, id))
|
originals(id).asInstanceOf[Accumulator[Any]] += value
|
||||||
accum.asInstanceOf[Accumulator[Any]] += value
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -225,7 +225,7 @@ private trait DAGScheduler extends Scheduler with Logging {
|
||||||
if (evt.reason == Success) {
|
if (evt.reason == Success) {
|
||||||
// A task ended
|
// A task ended
|
||||||
logInfo("Completed " + evt.task)
|
logInfo("Completed " + evt.task)
|
||||||
Accumulators.add(currentThread, evt.accumUpdates)
|
Accumulators.add(evt.accumUpdates)
|
||||||
evt.task match {
|
evt.task match {
|
||||||
case rt: ResultTask[_, _] =>
|
case rt: ResultTask[_, _] =>
|
||||||
results(rt.outputId) = evt.result.asInstanceOf[U]
|
results(rt.outputId) = evt.result.asInstanceOf[U]
|
||||||
|
|
Loading…
Reference in a new issue