[SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink

## What changes were proposed in this pull request?

Make sure accessing mutable variables in MemoryStream and MemorySink are protected by `synchronized`.
This is probably why MemorySinkSuite failed here: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/

## How was this patch tested?
Existing unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12261 from zsxwing/memory-race-condition.
This commit is contained in:
Shixiong Zhu 2016-04-11 10:42:51 -07:00 committed by Michael Armbrust
parent 5de26194a3
commit 2dacc81ec3

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
@GuardedBy("this")
protected val batches = new ArrayBuffer[Dataset[A]]
@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
def schema: StructType = encoder.schema
@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def addData(data: TraversableOnce[A]): Offset = {
import sqlContext.implicits._
this.synchronized {
currentOffset = currentOffset + 1
val ds = data.toVector.toDS()
logDebug(s"Adding ds: $ds")
this.synchronized {
currentOffset = currentOffset + 1
batches.append(ds)
currentOffset
}
@ -78,11 +82,13 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
override def toString: String = s"MemoryStream[${output.mkString(",")}]"
override def getOffset: Option[Offset] = if (batches.isEmpty) {
override def getOffset: Option[Offset] = synchronized {
if (batches.isEmpty) {
None
} else {
Some(currentOffset)
}
}
/**
* Returns the next batch of data that is available after `start`, if any is available.
@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
val startOrdinal =
start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
val newBlocks = batches.slice(startOrdinal, endOrdinal)
val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
logDebug(
s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
*/
class MemorySink(val schema: StructType) extends Sink with Logging {
/** An order list of batches that have been written to this [[Sink]]. */
@GuardedBy("this")
private val batches = new ArrayBuffer[Array[Row]]()
/** Returns all rows that are stored in this [[Sink]]. */
@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
batches.flatten
}
def lastBatch: Seq[Row] = batches.last
def lastBatch: Seq[Row] = synchronized { batches.last }
def toDebugString: String = synchronized {
batches.zipWithIndex.map { case (b, i) =>
@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
}.mkString("\n")
}
override def addBatch(batchId: Long, data: DataFrame): Unit = {
override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
if (batchId == batches.size) {
logDebug(s"Committing batch $batchId")
batches.append(data.collect())