[SPARK-24234][SS] Support multiple row writers in continuous processing shuffle reader.
## What changes were proposed in this pull request?

Design doc: https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii

Support multiple row writers in the continuous processing shuffle reader. Note that multiple read-side buffers (one queue per writer) ended up being the natural way to do this; with a single shared buffer, it is hard to express the constraint that the reader ends an epoch only when all writers have sent an epoch marker.

## How was this patch tested?

New unit tests.

Author: Jose Torres <torres.joseph.f+github@gmail.com>

Closes #21385 from jose-torres/multipleWrite.
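To make the buffering argument concrete, here is a minimal, Spark-free sketch of that constraint (the names `Msg`, `Row`, and `EpochMarker` are illustrative stand-ins, not the patch's types). With one queue per writer, "the epoch ends only when every writer has sent a marker" reduces to a flag array; with a single shared queue, a writer that runs ahead would interleave next-epoch rows into the epoch currently being drained. The real patch multiplexes the queues with a completion service rather than the round-robin blocking used here:

```scala
import java.util.concurrent.ArrayBlockingQueue

// Illustrative stand-ins for the patch's message types.
sealed trait Msg { def writerId: Int }
case class Row(writerId: Int, value: String) extends Msg
case class EpochMarker(writerId: Int) extends Msg

object EpochAlignmentSketch extends App {
  val numWriters = 2
  // One read-side buffer per writer: a fast writer can run ahead into the next
  // epoch without its rows leaking into the epoch currently being drained.
  val queues = Array.fill(numWriters)(new ArrayBlockingQueue[Msg](16))

  queues(0).put(Row(0, "w0-r0")); queues(0).put(EpochMarker(0))
  queues(1).put(Row(1, "w1-r0")); queues(1).put(EpochMarker(1))

  // Drain one epoch: stop pulling from a writer once it sends a marker.
  // The epoch only ends when every writer has sent one.
  val markerSeen = Array.fill(numWriters)(false)
  while (!markerSeen.forall(identity)) {
    for (w <- 0 until numWriters if !markerSeen(w)) {
      queues(w).take() match {
        case Row(_, v) => println(s"epoch row: $v")
        case EpochMarker(id) => markerSeen(id) = true
      }
    }
  }
  println("epoch complete")
}
```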
parent 53c06ddabb
commit 0fd68cb727
ContinuousShuffleReadRDD.scala

@@ -25,11 +25,16 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.NextIterator
 
-case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
+case class ContinuousShuffleReadPartition(
+    index: Int,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
   // Initialized only on the executor, and only once even as we call compute() multiple times.
   lazy val (reader: ContinuousShuffleReader, endpoint) = {
     val env = SparkEnv.get.rpcEnv
-    val receiver = new UnsafeRowReceiver(queueSize, env)
+    val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, epochIntervalMs, env)
     val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)
@@ -42,16 +47,24 @@ case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Pa
  * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their
  * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks
  * poll from their receiver until an epoch marker is sent.
+ *
+ * @param sc the RDD context
+ * @param numPartitions the number of read partitions for this RDD
+ * @param queueSize the size of the row buffers to use
+ * @param numShuffleWriters the number of continuous shuffle writers feeding into this RDD
+ * @param epochIntervalMs the checkpoint interval of the streaming query
  */
 class ContinuousShuffleReadRDD(
     sc: SparkContext,
     numPartitions: Int,
-    queueSize: Int = 1024)
+    queueSize: Int = 1024,
+    numShuffleWriters: Int = 1,
+    epochIntervalMs: Long = 1000)
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
     (0 until numPartitions).map { partIndex =>
-      ContinuousShuffleReadPartition(partIndex, queueSize)
+      ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, epochIntervalMs)
     }.toArray
   }
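For orientation, a sketch of how this RDD is driven (it mirrors the test suite further down; `ctx` stands for a TaskContext prepared by the test harness, and `unsafeRow` is the suite's row-building helper):

```scala
val rdd = new ContinuousShuffleReadRDD(
  sparkContext, numPartitions = 1, numShuffleWriters = 2)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

// Messages carry the id of the writer they came from.
endpoint.askSync[Unit](ReceiverRow(0, unsafeRow(111)))
endpoint.askSync[Unit](ReceiverEpochMarker(0))
endpoint.askSync[Unit](ReceiverEpochMarker(1))

// compute() returns one epoch's rows: here, just the single row from writer 0.
val epoch = rdd.compute(rdd.partitions(0), ctx)
```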
UnsafeRowReceiver.scala

@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.execution.streaming.continuous.shuffle
 
-import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
 
+import scala.collection.mutable
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -27,10 +29,17 @@ import org.apache.spark.util.NextIterator
 
 /**
  * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
+ *
+ * Each message comes tagged with writerId, identifying which writer the message is coming
+ * from. The receiver will only begin the next epoch once all writers have sent an epoch
+ * marker ending the current epoch.
  */
-private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable
-private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage
-private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
+private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable {
+  def writerId: Int
+}
+private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
+  extends UnsafeRowReceiverMessage
+private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends UnsafeRowReceiverMessage
 
 /**
  * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
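To make the alignment rule concrete, a hedged sketch of unaligned traffic from two writers (`send` and `unsafeRow` are the test suite's helpers; the row values are illustrative):

```scala
send(
  endpoint,
  ReceiverRow(0, unsafeRow("a0")),
  ReceiverEpochMarker(0),          // writer 0 ends epoch 1
  ReceiverRow(0, unsafeRow("a1")),
  ReceiverEpochMarker(0),          // writer 0 ends epoch 2
  ReceiverEpochMarker(1),          // writer 1 ends epoch 1 (it was empty)
  ReceiverRow(1, unsafeRow("b0")),
  ReceiverEpochMarker(1)           // writer 1 ends epoch 2
)
// The reader collates: epoch 1 yields {"a0"}, epoch 2 yields {"a1", "b0"}.
```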
@@ -41,11 +50,15 @@ private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessa
  */
 private[shuffle] class UnsafeRowReceiver(
     queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long,
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
   // Note that this queue will be drained from the main task thread and populated in the RPC
   // response thread.
-  private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+  private val queues = Array.fill(numShuffleWriters) {
+    new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+  }
 
   // Exposed for testing to determine if the endpoint gets stopped on task end.
   private[shuffle] val stopped = new AtomicBoolean(false)
@@ -56,20 +69,70 @@ private[shuffle] class UnsafeRowReceiver(
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case r: UnsafeRowReceiverMessage =>
-      queue.put(r)
+      queues(r.writerId).put(r)
       context.reply(())
   }
 
   override def read(): Iterator[UnsafeRow] = {
     new NextIterator[UnsafeRow] {
-      override def getNext(): UnsafeRow = queue.take() match {
-        case ReceiverRow(r) => r
-        case ReceiverEpochMarker() =>
-          finished = true
-          null
-      }
+      // An array of flags for whether each writer ID has gotten an epoch marker.
+      private val writerEpochMarkersReceived = Array.fill(numShuffleWriters)(false)
+
+      private val executor = Executors.newFixedThreadPool(numShuffleWriters)
+      private val completion = new ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
+
+      private def completionTask(writerId: Int) = new Callable[UnsafeRowReceiverMessage] {
+        override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
+      }
 
-      override def close(): Unit = {}
+      // Initialize by submitting tasks to read the first row from each writer.
+      (0 until numShuffleWriters).foreach(writerId => completion.submit(completionTask(writerId)))
+
+      /**
+       * In each call to getNext(), we pull the next row available in the completion queue, and then
+       * submit another task to read the next row from the writer which returned it.
+       *
+       * When a writer sends an epoch marker, we note that it's finished and don't submit another
+       * task for it in this epoch. The iterator is over once all writers have sent an epoch marker.
+       */
+      override def getNext(): UnsafeRow = {
+        var nextRow: UnsafeRow = null
+        while (!finished && nextRow == null) {
+          completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match {
+            case null =>
+              // Try again if the poll didn't wait long enough to get a real result.
+              // But we should be getting at least an epoch marker every checkpoint interval.
+              val writerIdsUncommitted = writerEpochMarkersReceived.zipWithIndex.collect {
+                case (flag, idx) if !flag => idx
+              }
+              logWarning(
+                s"Completion service failed to make progress after $epochIntervalMs ms. Waiting " +
+                  s"for writers $writerIdsUncommitted to send epoch markers.")
+
+            // The completion service guarantees this future will be available immediately.
+            case future => future.get() match {
+              case ReceiverRow(writerId, r) =>
+                // Start reading the next element in the queue we just took from.
+                completion.submit(completionTask(writerId))
+                nextRow = r
+              case ReceiverEpochMarker(writerId) =>
+                // Don't read any more from this queue. If all the writers have sent epoch markers,
+                // the epoch is over; otherwise we need to loop again to poll from the remaining
+                // writers.
+                writerEpochMarkersReceived(writerId) = true
+                if (writerEpochMarkersReceived.forall(_ == true)) {
+                  finished = true
+                }
+            }
+          }
+        }
+
+        nextRow
+      }
+
+      override def close(): Unit = {
+        executor.shutdownNow()
+      }
     }
   }
 }
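The read loop above multiplexes the per-writer queues through java.util.concurrent.ExecutorCompletionService: one blocked take() per queue, with poll(timeout) returning whichever completes first. A minimal, Spark-free sketch of that pattern (object and variable names are illustrative):

```scala
import java.util.concurrent.{ArrayBlockingQueue, Callable, Executors, ExecutorCompletionService, TimeUnit}

object CompletionServiceSketch extends App {
  // Three producer queues standing in for the per-writer row buffers.
  val sources = Array.fill(3)(new ArrayBlockingQueue[String](8))
  sources.zipWithIndex.foreach { case (q, i) => q.put(s"item-from-$i") }

  val executor = Executors.newFixedThreadPool(sources.length)
  val completion = new ExecutorCompletionService[(Int, String)](executor)
  def takeTask(i: Int) = new Callable[(Int, String)] {
    override def call(): (Int, String) = (i, sources(i).take())
  }
  // Seed one blocking take() per source; poll() then returns whichever finishes first.
  sources.indices.foreach(i => completion.submit(takeTask(i)))

  var received = 0
  while (received < sources.length) {
    completion.poll(100, TimeUnit.MILLISECONDS) match {
      case null => // timeout: the patch logs a warning here and retries
      case fut =>
        val (i, item) = fut.get()
        println(s"source $i produced $item")
        received += 1
        // The patch resubmits takeTask(i) at this point, unless writer i
        // already sent its epoch marker for the current epoch.
    }
  }
  executor.shutdownNow()
}
```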
ContinuousShuffleReadSuite.scala

@@ -21,7 +21,8 @@ import org.apache.spark.{TaskContext, TaskContextImpl}
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 
 class ContinuousShuffleReadSuite extends StreamTest {
 
@@ -30,6 +31,11 @@ class ContinuousShuffleReadSuite extends StreamTest {
       new GenericInternalRow(Array(value: Any)))
   }
 
+  private def unsafeRow(value: String) = {
+    UnsafeProjection.create(Array(StringType : DataType))(
+      new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
+  }
+
   private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = {
     messages.foreach(endpoint.askSync[Unit](_))
   }
@@ -57,8 +63,8 @@ class ContinuousShuffleReadSuite extends StreamTest {
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
       endpoint,
-      ReceiverEpochMarker(),
-      ReceiverRow(unsafeRow(111))
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111))
     )
 
     ctx.markTaskCompleted(None)
@@ -71,8 +77,11 @@ class ContinuousShuffleReadSuite extends StreamTest {
   test("receiver stopped with marker last") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
-    endpoint.askSync[Unit](ReceiverEpochMarker())
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0)
+    )
 
     ctx.markTaskCompleted(None)
     val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
@@ -86,10 +95,10 @@ class ContinuousShuffleReadSuite extends StreamTest {
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
       endpoint,
-      ReceiverRow(unsafeRow(111)),
-      ReceiverRow(unsafeRow(222)),
-      ReceiverRow(unsafeRow(333)),
-      ReceiverEpochMarker()
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverRow(0, unsafeRow(222)),
+      ReceiverRow(0, unsafeRow(333)),
+      ReceiverEpochMarker(0)
     )
 
     val iter = rdd.compute(rdd.partitions(0), ctx)
@@ -101,11 +110,11 @@ class ContinuousShuffleReadSuite extends StreamTest {
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
     send(
       endpoint,
-      ReceiverRow(unsafeRow(111)),
-      ReceiverEpochMarker(),
-      ReceiverRow(unsafeRow(222)),
-      ReceiverRow(unsafeRow(333)),
-      ReceiverEpochMarker()
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(222)),
+      ReceiverRow(0, unsafeRow(333)),
+      ReceiverEpochMarker(0)
     )
 
     val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
@@ -118,14 +127,15 @@ class ContinuousShuffleReadSuite extends StreamTest {
   test("empty epochs") {
     val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
     val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+
     send(
       endpoint,
-      ReceiverEpochMarker(),
-      ReceiverEpochMarker(),
-      ReceiverRow(unsafeRow(111)),
-      ReceiverEpochMarker(),
-      ReceiverEpochMarker(),
-      ReceiverEpochMarker()
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0)
     )
 
     assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
@@ -146,8 +156,8 @@ class ContinuousShuffleReadSuite extends StreamTest {
       // Send index for identification.
       send(
         part.endpoint,
-        ReceiverRow(unsafeRow(part.index)),
-        ReceiverEpochMarker()
+        ReceiverRow(0, unsafeRow(part.index)),
+        ReceiverEpochMarker(0)
       )
     }
@@ -160,25 +170,122 @@ class ContinuousShuffleReadSuite extends StreamTest {
   }
 
   test("blocks waiting for new rows") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
+    val epoch = rdd.compute(rdd.partitions(0), ctx)
 
     val readRowThread = new Thread {
       override def run(): Unit = {
-        // set the non-inheritable thread local
-        TaskContext.setTaskContext(ctx)
-        val epoch = rdd.compute(rdd.partitions(0), ctx)
-        epoch.next().getInt(0)
+        try {
+          epoch.next().getInt(0)
+        } catch {
+          case _: InterruptedException => // do nothing - expected at test ending
+        }
       }
     }
 
     try {
       readRowThread.start()
       eventually(timeout(streamingTimeout)) {
-        assert(readRowThread.getState == Thread.State.WAITING)
+        assert(readRowThread.getState == Thread.State.TIMED_WAITING)
       }
     } finally {
      readRowThread.interrupt()
      readRowThread.join()
    }
  }
+
+  test("multiple writers") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(1),
+      ReceiverEpochMarker(2)
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
+      Set("writer0-row0", "writer1-row0", "writer2-row0"))
+  }
+
+  test("epoch only ends when all writers send markers") {
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs = Long.MaxValue)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(2)
+    )
+
+    val epoch = rdd.compute(rdd.partitions(0), ctx)
+    val rows = (0 until 3).map(_ => epoch.next()).toSet
+    assert(rows.map(_.getUTF8String(0).toString) ==
+      Set("writer0-row0", "writer1-row0", "writer2-row0"))
+
+    // After checking the right rows, block until we get an epoch marker indicating there's no
+    // next. (Also fail the assertion if for some reason we get a row.)
+    val readEpochMarkerThread = new Thread {
+      override def run(): Unit = {
+        assert(!epoch.hasNext)
+      }
+    }
+
+    readEpochMarkerThread.start()
+    eventually(timeout(streamingTimeout)) {
+      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    // Send the last epoch marker - now the epoch should finish.
+    send(endpoint, ReceiverEpochMarker(1))
+    eventually(timeout(streamingTimeout)) {
+      !readEpochMarkerThread.isAlive
+    }
+
+    // Join to pick up assertion failures.
+    readEpochMarkerThread.join()
+  }
+
+  test("writer epochs non aligned") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, numShuffleWriters = 3)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    // We send multiple epochs for 0, then multiple for 1, then multiple for 2. The receiver should
+    // collate them as though the markers were aligned in the first place.
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow("writer0-row1")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+
+      ReceiverEpochMarker(1),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverEpochMarker(1),
+      ReceiverRow(1, unsafeRow("writer1-row1")),
+      ReceiverEpochMarker(1),
+
+      ReceiverEpochMarker(2),
+      ReceiverEpochMarker(2),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(2)
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
+    assert(firstEpoch == Set("writer0-row0"))
+
+    val secondEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
+    assert(secondEpoch == Set("writer0-row1", "writer1-row0"))
+
+    val thirdEpoch = rdd.compute(rdd.partitions(0), ctx).map(_.getUTF8String(0).toString).toSet
+    assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
+  }
 }