[SPARK-23697][CORE] LegacyAccumulatorWrapper should define isZero correctly

## What changes were proposed in this pull request?

Accumulators written against the Spark 1.x API may no longer work with Spark 2.x, because `LegacyAccumulatorWrapper.isZero` may return the wrong answer if the user's `AccumulableParam` doesn't define `equals`/`hashCode`.
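
To see why, note that when a value class doesn't override `equals`, Scala's `==` falls back to reference equality, and `AccumulableParam.zero` typically allocates a fresh object on every call, so the old check `_value == param.zero(initialValue)` could never succeed. A minimal sketch (the `MyData` class and `zero` method are illustrative, mirroring the new test below):

```scala
object ReferenceEqualityDemo extends App {
  // A value type with no equals/hashCode override, as a user might write.
  class MyData(val i: Int) extends Serializable

  // Stand-in for AccumulableParam.zero: returns a fresh instance each call.
  def zero(initialValue: MyData): MyData = new MyData(0)

  val _value = zero(new MyData(0))
  // Without equals, == is reference equality, so comparing _value against
  // a newly allocated zero is always false: this is the old isZero check.
  println(_value == zero(_value)) // false, even though both wrap 0
}
```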

This PR fixes the issue by making `LegacyAccumulatorWrapper.isZero` perform a reference equality check against a cached zero value.
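
The semantics of the fix, in a runnable sketch (names are illustrative, not Spark's):

```scala
object CachedZeroDemo extends App {
  // Illustrative value type without equals/hashCode.
  class MyData(val i: Int) extends Serializable

  // The fix caches the zero instance once (in Spark: a @transient lazy val),
  // and reset() assigns this exact instance back to the current value.
  val zero = new MyData(0)
  var value: MyData = zero

  // isZero now means "value is the cached zero object", checked with eq
  // (object identity). The real code casts to AnyRef because the value
  // type is a generic parameter R.
  def isZero: Boolean = value eq zero

  println(isZero) // true: value is exactly the cached zero instance
  value = new MyData(1)
  println(isZero) // false after an update
}
```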

## How was this patch tested?

A new test in `AccumulatorV2Suite` (see the diff below).

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21229 from cloud-fan/accumulator.
Commit 4d5de4d303 (parent 7f1b6b182e), 2018-05-04 19:20:15 +08:00
2 changed files with 23 additions and 2 deletions

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala:

```diff
@@ -486,7 +486,9 @@ class LegacyAccumulatorWrapper[R, T](
     param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
   private[spark] var _value = initialValue // Current value on driver
 
-  override def isZero: Boolean = _value == param.zero(initialValue)
+  @transient private lazy val _zero = param.zero(initialValue)
+
+  override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
 
   override def copy(): LegacyAccumulatorWrapper[R, T] = {
     val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -495,7 +497,7 @@ class LegacyAccumulatorWrapper[R, T](
   }
 
   override def reset(): Unit = {
-    _value = param.zero(initialValue)
+    _value = _zero
   }
 
   override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
```
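
Two details worth noting in the change above: the cached zero is a `@transient lazy val`, so it is recomputed on demand after deserialization rather than shipped with the accumulator, and `reset()` now assigns that same cached instance to `_value`, which is what makes the reference check in `isZero` reliable for `copyAndReset()` copies.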

core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala:

```diff
@@ -18,6 +18,7 @@
 
 package org.apache.spark.util
 
 import org.apache.spark._
+import org.apache.spark.serializer.JavaSerializer
 
 class AccumulatorV2Suite extends SparkFunSuite {
@@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
     assert(acc3.isZero)
     assert(acc3.value === "")
   }
+
+  test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
+    class MyData(val i: Int) extends Serializable
+    val param = new AccumulatorParam[MyData] {
+      override def zero(initialValue: MyData): MyData = new MyData(0)
+      override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
+    }
+
+    val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
+    acc.metadata = AccumulatorMetadata(
+      AccumulatorContext.newId(),
+      Some("test"),
+      countFailedValues = false)
+    AccumulatorContext.register(acc)
+
+    val ser = new JavaSerializer(new SparkConf).newInstance()
+    ser.serialize(acc)
+  }
 }
```
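
The test reproduces the original failure mode: serializing a registered accumulator goes through `AccumulatorV2.writeReplace`, which calls `copyAndReset()` and asserts that the copy `isZero` ("copyAndReset must return a zero value copy"); before this fix, that assertion failed whenever the value type lacked `equals`/`hashCode`. To try it locally, something like the following should run the suite:

```
build/sbt "core/testOnly *AccumulatorV2Suite"
```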