[SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink
## What changes were proposed in this pull request?

Make sure accesses to the mutable variables in MemoryStream and MemorySink are protected by `synchronized`. This is probably why MemorySinkSuite failed here: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/

## How was this patch tested?

Existing unit tests.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12261 from zsxwing/memory-race-condition.
commit 2dacc81ec3 (parent 5de26194a3)
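For context, the hazard being fixed is an unsynchronized read-modify-write on shared mutable state. A minimal standalone sketch of the pattern (illustrative only, not Spark code):

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative only: the shape of the race MemoryStream had before this patch.
class RacyLog {
  private val batches = new ArrayBuffer[String]()
  private var currentOffset: Long = -1

  // Two threads can interleave between the increment and the append, so the
  // returned offset may not match the batch that was actually stored, and
  // ArrayBuffer itself can be corrupted by concurrent appends.
  def addData(value: String): Long = {
    currentOffset += 1
    batches.append(value)
    currentOffset
  }
}
```

Wrapping the body in `synchronized` makes the increment and the append one atomic step and gives other threads a happens-before edge, which is what this patch does throughout both classes.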
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
```
```diff
@@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected val encoder = encoderFor[A]
   protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
+
+  @GuardedBy("this")
   protected val batches = new ArrayBuffer[Dataset[A]]
 
+  @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   def schema: StructType = encoder.schema
```
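`@GuardedBy` (from the JSR-305 `javax.annotation.concurrent` package imported above) documents which lock protects a field; it adds no locking itself, but static analyzers can flag accesses made without holding the named lock. A small sketch of the convention:

```scala
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ArrayBuffer

class Journal {
  // "this" names the monitor that must be held when touching the field;
  // the annotation enforces nothing at runtime.
  @GuardedBy("this")
  private val entries = new ArrayBuffer[String]()

  def add(entry: String): Unit = synchronized { entries.append(entry) }
  def size: Int = synchronized { entries.size }
}
```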
```diff
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def addData(data: TraversableOnce[A]): Offset = {
     import sqlContext.implicits._
-    val ds = data.toVector.toDS()
-    logDebug(s"Adding ds: $ds")
     this.synchronized {
       currentOffset = currentOffset + 1
+      val ds = data.toVector.toDS()
+      logDebug(s"Adding ds: $ds")
       batches.append(ds)
       currentOffset
     }
```
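Because the increment, the append, and the returned offset now all sit in one critical section, an offset returned by `addData` always names a batch that is already in `batches`. A usage sketch of that guarantee (assumes a live `SQLContext` named `sqlContext` is in scope):

```scala
// Usage sketch only; assumes a live SQLContext named `sqlContext`.
import org.apache.spark.sql.execution.streaming.MemoryStream
import sqlContext.implicits._

val stream = MemoryStream[Int](0, sqlContext)
val offset = stream.addData(1 to 3)
// The returned offset is already visible to readers of the stream.
assert(stream.getOffset.contains(offset))
```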
```diff
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   override def toString: String = s"MemoryStream[${output.mkString(",")}]"
 
-  override def getOffset: Option[Offset] = if (batches.isEmpty) {
-    None
-  } else {
-    Some(currentOffset)
+  override def getOffset: Option[Offset] = synchronized {
+    if (batches.isEmpty) {
+      None
+    } else {
+      Some(currentOffset)
+    }
   }
 
   /**
```
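The reader-side fix matters because `getOffset` performs a check-then-act across two fields; without the writer's lock it could observe a non-empty buffer together with a stale offset, since the JMM gives no visibility guarantee without synchronization. An illustrative contrast (not Spark code):

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative only: reads spanning two shared fields must use the writer's lock.
class OffsetLog {
  private val batches = new ArrayBuffer[Int]()
  private var currentOffset: Long = -1

  def add(x: Int): Unit = synchronized { currentOffset += 1; batches.append(x) }

  // Unsafe: `isEmpty` and `currentOffset` are read without the lock, so the
  // pair may be torn by a concurrent `add`, and both reads may be stale.
  def getOffsetUnsafe: Option[Long] =
    if (batches.isEmpty) None else Some(currentOffset)

  // Safe: both reads happen atomically under the same lock as the writer.
  def getOffset: Option[Long] = synchronized {
    if (batches.isEmpty) None else Some(currentOffset)
  }
}
```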
```diff
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     val startOrdinal =
       start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
     val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
-    val newBlocks = batches.slice(startOrdinal, endOrdinal)
+    val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
 
     logDebug(
       s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
```
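Note the lock scope chosen here: only the cheap `slice` (which copies the selected elements into a new buffer) runs under the lock, while the potentially expensive `collect()` runs without it. The general pattern, as an illustrative sketch that uses the buffer itself as the monitor:

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative only: snapshot under the lock, process outside it, so the
// lock is never held across expensive work.
def snapshotAndProcess[T, R](buf: ArrayBuffer[T], from: Int, until: Int)
                            (process: Seq[T] => R): R = {
  val snapshot = buf.synchronized { buf.slice(from, until) } // short critical section
  process(snapshot)                                          // no lock held here
}
```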
```diff
@@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  */
 class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An ordered list of batches that have been written to this [[Sink]]. */
+  @GuardedBy("this")
   private val batches = new ArrayBuffer[Array[Row]]()
 
   /** Returns all rows that are stored in this [[Sink]]. */
```
```diff
@@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
-  def lastBatch: Seq[Row] = batches.last
+  def lastBatch: Seq[Row] = synchronized { batches.last }
 
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
```
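Even a single-expression read like `batches.last` needs the lock: `ArrayBuffer` is unsynchronized, a concurrent append may be resizing its backing array mid-read, and without synchronization the reader has no visibility guarantee for prior writes. An illustrative contrast (not Spark code):

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Row

// Illustrative only: one expression is not automatically atomic or visible.
def lastBatchUnsafe(batches: ArrayBuffer[Array[Row]]): Seq[Row] =
  batches.last                          // may race with a concurrent append

def lastBatchSafe(batches: ArrayBuffer[Array[Row]]): Seq[Row] =
  batches.synchronized { batches.last } // same lock as the writers
```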
```diff
@@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     }.mkString("\n")
   }
 
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
     if (batchId == batches.size) {
       logDebug(s"Committing batch $batchId")
       batches.append(data.collect())
```
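The `batchId == batches.size` guard, now evaluated under the lock, makes commits idempotent: a replayed batch id is compared against the buffer size atomically, so a retry can neither duplicate a batch nor interleave with another commit. The pattern in isolation, as an illustrative sketch in the style of `addBatch`:

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative only: an idempotent, lock-protected commit log.
class CommitLog[T] {
  private val batches = new ArrayBuffer[T]()

  def addBatch(batchId: Long, payload: => T): Unit = synchronized {
    if (batchId == batches.size) {
      batches.append(payload) // first commit of the next expected batch
    }                         // replayed or out-of-order ids are ignored here
  }

  def size: Int = synchronized { batches.size }
}
```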