[SPARK-17463][CORE] Make the values of CollectionAccumulator and SetAccumulator readable in a thread-safe way

## What changes were proposed in this pull request?

Make the values of CollectionAccumulator and SetAccumulator readable in a thread-safe way, fixing the ConcurrentModificationException reported in [JIRA](https://issues.apache.org/jira/browse/SPARK-17463).
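The race is easy to reproduce outside Spark. Below is a minimal, standalone sketch (plain Scala, not Spark code; the object name `SynchronizedListSketch` and the numbers are made up) of why a plain list cannot be read while another thread appends to it, and how `Collections.synchronizedList` makes individual reads and Java serialization safe, which is what the executor heartbeat needs when it serializes accumulator values while a task is still updating them.

```scala
import java.util.{ArrayList, Collections}

import scala.collection.JavaConverters._

object SynchronizedListSketch {
  def main(args: Array[String]): Unit = {
    // Reading (iterating or serializing) a plain ArrayList/ArrayBuffer while another
    // thread appends to it can fail with ConcurrentModificationException, the error
    // reported in SPARK-17463. Wrapping the list makes each individual operation
    // (add, toArray, size, and the collection's Java serialization) synchronized on
    // the list's monitor.
    val values = Collections.synchronizedList(new ArrayList[Int]())

    val writer = new Thread(new Runnable {
      override def run(): Unit = (1 to 100000).foreach { i => values.add(i) }
    })
    writer.start()

    // Safe concurrent read: toArray acquires the list's lock and returns a snapshot.
    println(s"snapshot while writing: ${values.toArray.length} elements")

    writer.join()
    // Once no more updates can arrive, iterating via asScala is also safe; this is
    // what the patch relies on for driver-side reads of accumulator values.
    println(s"final: ${values.asScala.size} elements")
  }
}
```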

## How was this patch tested?

Existing tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #15063 from zsxwing/SPARK-17463.
Authored by Shixiong Zhu on 2016-09-14 13:33:51 -07:00; committed by Josh Rosen
commit e33bfaed3b (parent ff6e4cbdc8)
5 changed files with 54 additions and 32 deletions

File: TaskMetrics.scala (org.apache.spark.executor)

@@ -17,6 +17,9 @@
 package org.apache.spark.executor
+import java.util.{ArrayList, Collections}
+
+import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
 import org.apache.spark._
@@ -99,7 +102,11 @@ class TaskMetrics private[spark] () extends Serializable {
   /**
    * Storage statuses of any blocks that have been updated as a result of this task.
    */
-  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.value
+  def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = {
+    // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
+    // `asScala` which accesses the internal values using `java.util.Iterator`.
+    _updatedBlockStatuses.value.asScala
+  }
   // Setters and increment-ers
   private[spark] def setExecutorDeserializeTime(v: Long): Unit =
@@ -114,8 +121,10 @@ class TaskMetrics private[spark] () extends Serializable {
   private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
   private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
     _updatedBlockStatuses.add(v)
-  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+  private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit =
     _updatedBlockStatuses.setValue(v)
+  private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+    _updatedBlockStatuses.setValue(v.asJava)
   /**
    * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted
@@ -268,7 +277,7 @@ private[spark] object TaskMetrics extends Logging {
       val name = info.name.get
       val value = info.update.get
       if (name == UPDATED_BLOCK_STATUSES) {
-        tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
+        tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]])
       } else {
         tm.nameToAccums.get(name).foreach(
           _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
@@ -299,8 +308,8 @@ private[spark] object TaskMetrics extends Logging {
 private[spark] class BlockStatusesAccumulator
-    extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
-  private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
+    extends AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]] {
+  private val _seq = Collections.synchronizedList(new ArrayList[(BlockId, BlockStatus)]())
   override def isZero(): Boolean = _seq.isEmpty
@@ -308,25 +317,27 @@ private[spark] class BlockStatusesAccumulator
   override def copy(): BlockStatusesAccumulator = {
     val newAcc = new BlockStatusesAccumulator
-    newAcc._seq = _seq.clone()
+    newAcc._seq.addAll(_seq)
     newAcc
   }
   override def reset(): Unit = _seq.clear()
-  override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+  override def add(v: (BlockId, BlockStatus)): Unit = _seq.add(v)
-  override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
-    : Unit = other match {
-    case o: BlockStatusesAccumulator => _seq ++= o.value
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  override def merge(
+      other: AccumulatorV2[(BlockId, BlockStatus), java.util.List[(BlockId, BlockStatus)]]): Unit = {
+    other match {
+      case o: BlockStatusesAccumulator => _seq.addAll(o.value)
+      case _ => throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+    }
   }
-  override def value: Seq[(BlockId, BlockStatus)] = _seq
+  override def value: java.util.List[(BlockId, BlockStatus)] = _seq
-  def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+  def setValue(newValue: java.util.List[(BlockId, BlockStatus)]): Unit = {
     _seq.clear()
-    _seq ++= newValue
+    _seq.addAll(newValue)
   }
 }
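For reference, the pattern applied to BlockStatusesAccumulator above can be written as a small standalone accumulator against the public AccumulatorV2 API. This is only a sketch under the assumption that spark-core is on the classpath; `SyncListAccumulator` and `_list` are illustrative names, not Spark code.

```scala
import java.util.{ArrayList, Collections}

import org.apache.spark.util.AccumulatorV2

// A generic list-backed accumulator following the same pattern as BlockStatusesAccumulator:
// the backing store is a synchronized java.util.List, so the value can be handed to another
// thread (e.g. the heartbeat reporter) without risking ConcurrentModificationException
// during serialization.
class SyncListAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
  private val _list = Collections.synchronizedList(new ArrayList[T]())

  override def isZero: Boolean = _list.isEmpty

  override def copy(): SyncListAccumulator[T] = {
    val newAcc = new SyncListAccumulator[T]
    // addAll copies via the source list's synchronized toArray, so copying does not
    // race with concurrent add() calls on this accumulator.
    newAcc._list.addAll(_list)
    newAcc
  }

  override def reset(): Unit = _list.clear()

  override def add(v: T): Unit = _list.add(v)

  override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
    case o: SyncListAccumulator[T] => _list.addAll(o.value)
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: java.util.List[T] = _list
}
```

Using `addAll` everywhere, instead of assigning or cloning a mutable buffer, keeps every access on the synchronized wrapper, so `copy()`, `merge()` and heartbeat serialization of `value` no longer race with concurrent `add()` calls.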

File: AccumulatorV2.scala (org.apache.spark.util)

@@ -19,7 +19,7 @@ package org.apache.spark.util
 import java.{lang => jl}
 import java.io.ObjectInputStream
-import java.util.ArrayList
+import java.util.{ArrayList, Collections}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
@@ -38,6 +38,9 @@ private[spark] case class AccumulatorMetadata(
 /**
  * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
  * type `OUT`.
+ *
+ * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
+ * (e.g., synchronized collections) because it will be read from other threads.
  */
 abstract class AccumulatorV2[IN, OUT] extends Serializable {
   private[spark] var metadata: AccumulatorMetadata = _
@@ -433,7 +436,7 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
  * @since 2.0.0
  */
 class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
-  private val _list: java.util.List[T] = new ArrayList[T]
+  private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]())
   override def isZero: Boolean = _list.isEmpty
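After this change a CollectionAccumulator can be reported from the heartbeat thread while tasks are still adding to it. A minimal driver-side usage sketch follows (local mode; the accumulator name `badRecords` and the filtering logic are made up):

```scala
import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}

object CollectionAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("acc-sketch"))

    // CollectionAccumulator[T] accumulates into a java.util.List[T]; with this patch the
    // backing list is a synchronized collection, so serializing it for heartbeats while
    // tasks call add() no longer throws ConcurrentModificationException.
    val badRecords = sc.collectionAccumulator[String]("badRecords")

    sc.parallelize(1 to 1000, 8).foreach { i =>
      if (i % 97 == 0) badRecords.add(s"record-$i") // executor-side updates
    }

    // On the driver, once the job has finished, all updates are merged and fixed,
    // so iterating the value via asScala is safe.
    println(badRecords.value.asScala.mkString(", "))

    sc.stop()
  }
}
```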

File: JsonProtocol.scala (org.apache.spark.util)

@@ -310,11 +310,12 @@ private[spark] object JsonProtocol {
         case v: Int => JInt(v)
         case v: Long => JInt(v)
         // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be
-        // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]`
+        // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]`
         case v =>
-          JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) =>
-            ("Block ID" -> id.toString) ~
-            ("Status" -> blockStatusToJson(status))
+          JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map {
+            case (id, status) =>
+              ("Block ID" -> id.toString) ~
+              ("Status" -> blockStatusToJson(status))
           })
       }
     } else {
@@ -743,7 +744,7 @@ private[spark] object JsonProtocol {
           val id = BlockId((blockJson \ "Block ID").extract[String])
           val status = blockStatusFromJson(blockJson \ "Status")
           (id, status)
-        }
+        }.asJava
       case _ => throw new IllegalArgumentException(s"unexpected json value $value for " +
         "accumulator " + name.get)
     }
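The JsonProtocol changes above are thin conversions at the Java/Scala boundary. A small sketch of the two directions (plain Scala; `ConvertersSketch` and the sample values are made up): `asScala`/`asJava` only wrap the underlying collection, while `toList` materializes an immutable snapshot before mapping to JSON.

```scala
import java.util.{ArrayList, Collections}

import scala.collection.JavaConverters._

object ConvertersSketch {
  def main(args: Array[String]): Unit = {
    val updates = Collections.synchronizedList(new ArrayList[(String, Long)]())
    updates.add(("rdd_0_0", 1024L))

    // Serialization direction: asScala wraps the Java list (no copy); toList then
    // materializes an immutable snapshot to map over, as in the JSON writer above.
    val snapshot: List[(String, Long)] = updates.asScala.toList

    // Deserialization direction: asJava wraps the rebuilt Scala Seq as a
    // java.util.List, matching the java.util.List-typed setter it is passed to.
    val javaView: java.util.List[(String, Long)] = snapshot.asJava

    println(snapshot)
    println(javaView)
  }
}
```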

File: JsonProtocolSuite.scala (org.apache.spark.util, test)

@@ -19,6 +19,7 @@ package org.apache.spark.util
 import java.util.Properties
+import scala.collection.JavaConverters._
 import scala.collection.Map
 import org.json4s.jackson.JsonMethods._
@@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite {
     })
     testAccumValue(Some(RESULT_SIZE), 3L, JInt(3))
     testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2))
-    testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson)
+    testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson)
     // For anything else, we just cast the value to a string
     testAccumValue(Some("anything"), blocks, JString(blocks.toString))
     testAccumValue(Some("anything"), 123, JString("123"))

File: package.scala (org.apache.spark.sql.execution.debug)

@@ -17,7 +17,9 @@
 package org.apache.spark.sql.execution
-import scala.collection.mutable.HashSet
+import java.util.Collections
+
+import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -107,18 +109,20 @@ package object debug {
   case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
     def output: Seq[Attribute] = child.output
-    class SetAccumulator[T] extends AccumulatorV2[T, HashSet[T]] {
-      private val _set = new HashSet[T]()
+    class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] {
+      private val _set = Collections.synchronizedSet(new java.util.HashSet[T]())
       override def isZero: Boolean = _set.isEmpty
-      override def copy(): AccumulatorV2[T, HashSet[T]] = {
+      override def copy(): AccumulatorV2[T, java.util.Set[T]] = {
         val newAcc = new SetAccumulator[T]()
-        newAcc._set ++= _set
+        newAcc._set.addAll(_set)
         newAcc
       }
       override def reset(): Unit = _set.clear()
-      override def add(v: T): Unit = _set += v
-      override def merge(other: AccumulatorV2[T, HashSet[T]]): Unit = _set ++= other.value
-      override def value: HashSet[T] = _set
+      override def add(v: T): Unit = _set.add(v)
+      override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = {
+        _set.addAll(other.value)
+      }
+      override def value: java.util.Set[T] = _set
     }
     /**
@@ -138,7 +142,9 @@ package object debug {
       debugPrint(s"== ${child.simpleString} ==")
       debugPrint(s"Tuples output: ${tupleCount.value}")
       child.output.zip(columnStats).foreach { case (attr, metric) =>
-        val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
+        // This is called on driver. All accumulator updates have a fixed value. So it's safe to use
+        // `asScala` which accesses the internal values using `java.util.Iterator`.
+        val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}")
         debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
     }
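One caveat the driver-side comments above rely on: `Collections.synchronizedSet` and `synchronizedList` make individual operations atomic, but iteration is only safe either while holding the collection's own monitor or after updates have stopped. A small sketch of both reads (plain Scala; `SynchronizedIterationSketch` and the sample values are made up):

```scala
import java.util.Collections

import scala.collection.JavaConverters._

object SynchronizedIterationSketch {
  def main(args: Array[String]): Unit = {
    val types = Collections.synchronizedSet(new java.util.HashSet[String]())
    types.add("IntegerType")
    types.add("StringType")

    // If another thread could still be adding, iteration must hold the set's own
    // monitor (the documented contract of Collections.synchronizedSet).
    val safeCopy = types.synchronized { types.asScala.toSet }

    // Once updates are fixed (e.g. on the driver after the job has finished, as in
    // DebugExec above), plain asScala iteration is fine.
    println(types.asScala.mkString("{", ",", "}"))
    println(safeCopy)
  }
}
```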