[SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly
## What changes were proposed in this pull request? It's possible that Accumulators of Spark 1.x may no longer work with Spark 2.x. This is because `LegacyAccumulatorWrapper.isZero` may return a wrong answer if `AccumulableParam` doesn't define equals/hashCode. This PR fixes this by using a reference equality check in `LegacyAccumulatorWrapper.isZero`. ## How was this patch tested? A new test. Author: Wenchen Fan <wenchen@databricks.com> Closes #21229 from cloud-fan/accumulator.
This commit is contained in:
parent
7f1b6b182e
commit
4d5de4d303
|
@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T](
|
||||||
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
|
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
|
||||||
private[spark] var _value = initialValue // Current value on driver
|
private[spark] var _value = initialValue // Current value on driver
|
||||||
|
|
||||||
override def isZero: Boolean = _value == param.zero(initialValue)
|
@transient private lazy val _zero = param.zero(initialValue)
|
||||||
|
|
||||||
|
override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
|
||||||
|
|
||||||
override def copy(): LegacyAccumulatorWrapper[R, T] = {
|
override def copy(): LegacyAccumulatorWrapper[R, T] = {
|
||||||
val acc = new LegacyAccumulatorWrapper(initialValue, param)
|
val acc = new LegacyAccumulatorWrapper(initialValue, param)
|
||||||
|
@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T](
|
||||||
}
|
}
|
||||||
|
|
||||||
override def reset(): Unit = {
|
override def reset(): Unit = {
|
||||||
_value = param.zero(initialValue)
|
_value = _zero
|
||||||
}
|
}
|
||||||
|
|
||||||
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
|
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.spark.util
|
package org.apache.spark.util
|
||||||
|
|
||||||
import org.apache.spark._
|
import org.apache.spark._
|
||||||
|
import org.apache.spark.serializer.JavaSerializer
|
||||||
|
|
||||||
class AccumulatorV2Suite extends SparkFunSuite {
|
class AccumulatorV2Suite extends SparkFunSuite {
|
||||||
|
|
||||||
|
@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
|
||||||
assert(acc3.isZero)
|
assert(acc3.isZero)
|
||||||
assert(acc3.value === "")
|
assert(acc3.value === "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
|
||||||
|
class MyData(val i: Int) extends Serializable
|
||||||
|
val param = new AccumulatorParam[MyData] {
|
||||||
|
override def zero(initialValue: MyData): MyData = new MyData(0)
|
||||||
|
override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
|
||||||
|
}
|
||||||
|
|
||||||
|
val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
|
||||||
|
acc.metadata = AccumulatorMetadata(
|
||||||
|
AccumulatorContext.newId(),
|
||||||
|
Some("test"),
|
||||||
|
countFailedValues = false)
|
||||||
|
AccumulatorContext.register(acc)
|
||||||
|
|
||||||
|
val ser = new JavaSerializer(new SparkConf).newInstance()
|
||||||
|
ser.serialize(acc)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue