[SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly
## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? a new test Author: Wenchen Fan <wenchen@databricks.com> Closes #21229 from cloud-fan/accumulator.
This commit is contained in:
parent
7f1b6b182e
commit
4d5de4d303
|
@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T](
|
|||
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
|
||||
private[spark] var _value = initialValue // Current value on driver
|
||||
|
||||
override def isZero: Boolean = _value == param.zero(initialValue)
|
||||
@transient private lazy val _zero = param.zero(initialValue)
|
||||
|
||||
override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
|
||||
|
||||
override def copy(): LegacyAccumulatorWrapper[R, T] = {
|
||||
val acc = new LegacyAccumulatorWrapper(initialValue, param)
|
||||
|
@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T](
|
|||
}
|
||||
|
||||
override def reset(): Unit = {
|
||||
_value = param.zero(initialValue)
|
||||
_value = _zero
|
||||
}
|
||||
|
||||
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.util
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.serializer.JavaSerializer
|
||||
|
||||
class AccumulatorV2Suite extends SparkFunSuite {
|
||||
|
||||
|
@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
|
|||
assert(acc3.isZero)
|
||||
assert(acc3.value === "")
|
||||
}
|
||||
|
||||
test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
|
||||
class MyData(val i: Int) extends Serializable
|
||||
val param = new AccumulatorParam[MyData] {
|
||||
override def zero(initialValue: MyData): MyData = new MyData(0)
|
||||
override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
|
||||
}
|
||||
|
||||
val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
|
||||
acc.metadata = AccumulatorMetadata(
|
||||
AccumulatorContext.newId(),
|
||||
Some("test"),
|
||||
countFailedValues = false)
|
||||
AccumulatorContext.register(acc)
|
||||
|
||||
val ser = new JavaSerializer(new SparkConf).newInstance()
|
||||
ser.serialize(acc)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue