[SPARK-24234][SS] Support multiple row writers in continuous processing shuffle reader.

## What changes were proposed in this pull request?

https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii

Support multiple row writers in the continuous processing shuffle reader.

Note that having multiple read-side buffers (one per writer) ended up being the natural way to do this; otherwise it's hard to express the constraint that the reader finishes an epoch only once every writer has sent an epoch marker.
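
The patch realizes this with one blocking queue per writer, drained through an `ExecutorCompletionService`, plus a per-writer flag recording whether that writer's epoch marker has arrived (see the `UnsafeRowReceiver` changes below). The following is a rough, self-contained sketch of that pattern only; the names here (`MultiWriterEpochReader`, `RowMessage`, `MarkerMessage`) are illustrative and not part of the patch:

```scala
import java.util.concurrent.{ArrayBlockingQueue, Callable, ExecutorCompletionService, Executors}

sealed trait Message { def writerId: Int }
case class RowMessage(writerId: Int, value: String) extends Message
case class MarkerMessage(writerId: Int) extends Message

class MultiWriterEpochReader(numWriters: Int, queueSize: Int) {
  // One buffer per writer, so a slow writer can't hide which writers still owe a marker.
  private val queues = Array.fill(numWriters)(new ArrayBlockingQueue[Message](queueSize))
  private val executor = Executors.newFixedThreadPool(numWriters)
  private val completion = new ExecutorCompletionService[Message](executor)

  // Receive side: route each incoming message to its writer's queue.
  def put(m: Message): Unit = queues(m.writerId).put(m)

  private def takeFrom(writerId: Int) = new Callable[Message] {
    override def call(): Message = queues(writerId).take()
  }

  // Read side: drain one epoch. The epoch ends only once every writer has sent a marker.
  def readEpoch(): Seq[String] = {
    val markerSeen = Array.fill(numWriters)(false)
    (0 until numWriters).foreach(id => completion.submit(takeFrom(id)))
    val rows = Seq.newBuilder[String]
    while (!markerSeen.forall(identity)) {
      completion.take().get() match {
        case RowMessage(id, value) =>
          rows += value
          completion.submit(takeFrom(id)) // keep polling the writer that just produced a row
        case MarkerMessage(id) =>
          markerSeen(id) = true           // stop polling this writer until the next epoch
      }
    }
    rows.result()
  }

  def close(): Unit = executor.shutdownNow()
}
```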

## How was this patch tested?

New unit tests in ContinuousShuffleReadSuite, covering multiple writers, epochs that end only when every writer has sent a marker, and non-aligned writer epochs.
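
For reference, the core multi-writer case is condensed below from the new ContinuousShuffleReadSuite tests in this diff; it relies on the suite's `unsafeRow` and `send` helpers and the task context `ctx`, so it is a summary rather than a standalone snippet:

```scala
test("multiple writers") {
  val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
  val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
  // Each message carries the id of the writer it comes from.
  send(
    endpoint,
    ReceiverRow(0, unsafeRow("writer0-row0")),
    ReceiverRow(1, unsafeRow("writer1-row0")),
    ReceiverRow(2, unsafeRow("writer2-row0")),
    ReceiverEpochMarker(0),
    ReceiverEpochMarker(1),
    ReceiverEpochMarker(2)
  )
  // The epoch contains rows from all three writers and ends only after all three markers.
  val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
  assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
    Set("writer0-row0", "writer1-row0", "writer2-row0"))
}
```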

Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #21385 from jose-torres/multipleWrite.
Authored by Jose Torres on 2018-05-24 17:08:52 -07:00; committed by Tathagata Das.
Commit 0fd68cb727 (parent 53c06ddabb).
3 changed files with 227 additions and 44 deletions.


@@ -25,11 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.NextIterator
case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
case class ContinuousShuffleReadPartition(
index: Int,
queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long)
extends Partition {
// Initialized only on the executor, and only once even as we call compute() multiple times.
lazy val (reader: ContinuousShuffleReader, endpoint) = {
val env = SparkEnv.get.rpcEnv
val receiver = new UnsafeRowReceiver(queueSize, env)
val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env)
val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
TaskContext.get().addTaskCompletionListener { ctx =>
env.stop(endpoint)
@@ -42,16 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa
* RDD at the map side of each continuous processing shuffle task. Upstream tasks send their
* shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks
* poll from their receiver until an epoch marker is sent.
*
* @param sc the RDD context
* @param numPartitions the number of read partitions for this RDD
* @param queueSize the size of the row buffers to use
* @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD
* @param epochIntervalMs the checkpoint interval of the streaming query
*/
class ContinuousShuffleReadRDD(
sc: SparkContext,
numPartitions: Int,
queueSize: Int = 1024)
queueSize: Int = 1024,
numShuffleWriters: Int = 1,
epochIntervalMs: Long = 1000)
extends RDD[UnsafeRow](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
(0 until numPartitions).map { partIndex =>
ContinuousShuffleReadPartition(partIndex, queueSize)
ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
}.toArray
}


@@ -17,9 +17,11 @@
package org.apache.spark.sql.execution.streaming.continuous.shuffle
import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -27,10 +29,17 @@ import org.apache.spark.util.NextIterator
/**
* Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
*
* Each message comes tagged with writerId, identifying which writer the message is coming
* from. The receiver will only begin the next epoch once all writers have sent an epoch
* marker ending the current epoch.
*/
private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable
private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage
private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable {
def writerId: Int
}
private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
extends UnsafeRowReceiverMessage
private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage
/**
* RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
@@ -41,11 +50,15 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa
*/
private[shuffle] class UnsafeRowReceiver(
queueSize: Int,
numShuffleWriters: Int,
epochIntervalMs: Long,
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
// Note that this queue will be drained from the main task thread and populated in the RPC
// response thread.
private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
private val queues = Array.fill(numShuffleWriters) {
new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
}
// Exposed for testing to determine if the endpoint gets stopped on task end.
private[shuffle] val stopped = new AtomicBoolean(false)
@@ -56,20 +69,70 @@ private[shuffle] class UnsafeRowReceiver(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case r: UnsafeRowReceiverMessage =>
queue.put(r)
queues(r.writerId).put(r)
context.reply(())
}
override def read(): Iterator[UnsafeRow] = {
new NextIterator[UnsafeRow] {
override def getNext(): UnsafeRow = queue.take() match {
case ReceiverRow(r) => r
case ReceiverEpochMarker() =>
finished = true
null
// An array of flags for whether each writer ID has gotten an epoch marker.
private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false)
private val executor = Executors.newFixedThreadPool(numShuffleWriters)
private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] {
override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
}
override def close(): Unit = {}
// Initialize by submitting tasks to read the first row from each writer.
(0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId)))
/**
* In each call to getNext(), we pull the next row available in the completion queue, and then
* submit another task to read the next row from the writer which returned it.
*
* When a writer sends an epoch marker, we note that it's finished and don't submit another
* task for it in this epoch. The iterator is over once all writers have sent an epoch marker.
*/
override def getNext(): UnsafeRow = {
var nextRow: UnsafeRow = null
while (!finished && nextRow == null) {
completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match {
case null =>
// Try again if the poll didn't wait long enough to get a real result.
// But we should be getting at least an epoch marker every checkpoint interval.
val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect {
case (flag, idx) if !flag => idx
}
logWarning(
s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
s"for writers $writerIdsUncommitted to send epoch markers.")
// The completion service guarantees this future will be available immediately.
case future => future.get() match {
case ReceiverRow(writerId, r) =>
// Start reading the next element in the queue we just took from.
completion.submit(completionTask(writerId))
nextRow = r
case ReceiverEpochMarker(writerId) =>
// Don't read any more from this queue. If all the writers have sent epoch markers,
// the epoch is over; otherwise we need to loop again to poll from the remaining
// writers.
writerEpochMarkersReceived(writerId) = true
if (writerEpochMarkersReceived.forall(_ == true)) {
finished = true
}
}
}
}
nextRow
}
override def close(): Unit = {
executor.shutdownNow()
}
}
}
}


@@ -21,7 +21,8 @@ import org.apache.spark.{TaskContext, TaskContextImpl}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String
class ContinuousShuffleReadSuite extends StreamTest {
@@ -30,6 +31,11 @@ class ContinuousShuffleReadSuite extends StreamTest {
new GenericInternalRow(Array(value: Any)))
}
private def unsafeRow(value: String) = {
UnsafeProjection.create(Array(StringType : DataType))(
new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
}
private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = {
messages.foreach(endpoint.askSync[Unit](_))
}
@@ -57,8 +63,8 @@ class ContinuousShuffleReadSuite extends StreamTest {
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverEpochMarker(),
ReceiverRow(unsafeRow(111))
ReceiverEpochMarker(0),
ReceiverRow(0, unsafeRow(111))
)
ctx.markTaskCompleted(None)
@@ -71,8 +77,11 @@
test("receiver stopped with marker last") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
endpoint.askSync[Unit](ReceiverEpochMarker())
send(
endpoint,
ReceiverRow(0, unsafeRow(111)),
ReceiverEpochMarker(0)
)
ctx.markTaskCompleted(None)
val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
@@ -86,10 +95,10 @@
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverRow(unsafeRow(111)),
ReceiverRow(unsafeRow(222)),
ReceiverRow(unsafeRow(333)),
ReceiverEpochMarker()
ReceiverRow(0, unsafeRow(111)),
ReceiverRow(0, unsafeRow(222)),
ReceiverRow(0, unsafeRow(333)),
ReceiverEpochMarker(0)
)
val iter = rdd.compute(rdd.partitions(0), ctx)
@@ -101,11 +110,11 @@
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverRow(unsafeRow(111)),
ReceiverEpochMarker(),
ReceiverRow(unsafeRow(222)),
ReceiverRow(unsafeRow(333)),
ReceiverEpochMarker()
ReceiverRow(0, unsafeRow(111)),
ReceiverEpochMarker(0),
ReceiverRow(0, unsafeRow(222)),
ReceiverRow(0, unsafeRow(333)),
ReceiverEpochMarker(0)
)
val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
@@ -118,14 +127,15 @@
test("empty epochs") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverEpochMarker(),
ReceiverEpochMarker(),
ReceiverRow(unsafeRow(111)),
ReceiverEpochMarker(),
ReceiverEpochMarker(),
ReceiverEpochMarker()
ReceiverEpochMarker(0),
ReceiverEpochMarker(0),
ReceiverRow(0, unsafeRow(111)),
ReceiverEpochMarker(0),
ReceiverEpochMarker(0),
ReceiverEpochMarker(0)
)
assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
@@ -146,8 +156,8 @@
// Send index for identification.
send(
part.endpoint,
ReceiverRow(unsafeRow(part.index)),
ReceiverEpochMarker()
ReceiverRow(0, unsafeRow(part.index)),
ReceiverEpochMarker(0)
)
}
@@ -160,25 +170,122 @@
}
test("blocks waiting for new rows") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val rdd = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
val epoch = rdd.compute(rdd.partitions(0), ctx)
val readRowThread = new Thread {
override def run(): Unit = {
// set the non-inheritable thread local
TaskContext.setTaskContext(ctx)
val epoch = rdd.compute(rdd.partitions(0), ctx)
epoch.next().getInt(0)
try {
epoch.next().getInt(0)
} catch {
case _: InterruptedException => // do nothing - expected at test ending
}
}
}
try {
readRowThread.start()
eventually(timeout(streamingTimeout)) {
assert(readRowThread.getState == Thread.State.WAITING)
assert(readRowThread.getState == Thread.State.TIMED_WAITING)
}
} finally {
readRowThread.interrupt()
readRowThread.join()
}
}
test("multiple writers") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverRow(0, unsafeRow("writer0-row0")),
ReceiverRow(1, unsafeRow("writer1-row0")),
ReceiverRow(2, unsafeRow("writer2-row0")),
ReceiverEpochMarker(0),
ReceiverEpochMarker(1),
ReceiverEpochMarker(2)
)
val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
Set("writer0-row0", "writer1-row0", "writer2-row0"))
}
test("epoch only ends when all writers send markers") {
val rdd = new ContinuousShuffleReadRDD(
sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
send(
endpoint,
ReceiverRow(0, unsafeRow("writer0-row0")),
ReceiverRow(1, unsafeRow("writer1-row0")),
ReceiverRow(2, unsafeRow("writer2-row0")),
ReceiverEpochMarker(0),
ReceiverEpochMarker(2)
)
val epoch = rdd.compute(rdd.partitions(0), ctx)
val rows = (0 until 3).map(_ => epoch.next()).toSet
assert(rows.map(_.getUTF8String(0).toString) ==
Set("writer0-row0", "writer1-row0", "writer2-row0"))
// After checking the right rows, block until we get an epoch marker indicating there's no next.
// (Also fail the assertion if for some reason we get a row.)
val readEpochMarkerThread = new Thread {
override def run(): Unit = {
assert(!epoch.hasNext)
}
}
readEpochMarkerThread.start()
eventually(timeout(streamingTimeout)) {
assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
}
// Send the last epoch marker - now the epoch should finish.
send(endpoint, ReceiverEpochMarker(1))
eventually(timeout(streamingTimeout)) {
!readEpochMarkerThread.isAlive
}
// Join to pick up assertion failures.
readEpochMarkerThread.join()
}
test("writer epochs non aligned") {
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
// We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should
// collate them as though the markers were aligned in the first place.
send(
endpoint,
ReceiverRow(0, unsafeRow("writer0-row0")),
ReceiverEpochMarker(0),
ReceiverRow(0, unsafeRow("writer0-row1")),
ReceiverEpochMarker(0),
ReceiverEpochMarker(0),
ReceiverEpochMarker(1),
ReceiverRow(1, unsafeRow("writer1-row0")),
ReceiverEpochMarker(1),
ReceiverRow(1, unsafeRow("writer1-row1")),
ReceiverEpochMarker(1),
ReceiverEpochMarker(2),
ReceiverEpochMarker(2),
ReceiverRow(2, unsafeRow("writer2-row0")),
ReceiverEpochMarker(2)
)
val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
assert(firstEpoch == Set("writer0-row0"))
val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
assert(secondEpoch == Set("writer0-row1", "writer1-row0"))
val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
}
}