[SPARK-24160] ShuffleBlockFetcherIterator should fail if it receives zero-size blocks

## What changes were proposed in this pull request?

This patch modifies `ShuffleBlockFetcherIterator` so that the receipt of zero-size blocks is treated as an error. This is done as a preventative measure to guard against a potential source of data loss bugs.

In the shuffle layer, we guarantee that zero-size blocks will never be requested: a block containing zero records is always 0 bytes in size and is marked as empty, so executors will never legitimately request it. However, the shuffle-read path did not take full advantage of this invariant: it never explicitly checked that received blocks were non-zero-size.
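To make the request-side half of that invariant concrete, here is a minimal sketch (the names are simplified stand-ins, not the actual `ShuffleBlockFetcherIterator` internals): fetch requests are built from per-block sizes, and zero-size entries are simply skipped, so a zero-size buffer should never arrive on the receive side.

```scala
// Simplified sketch: zero-size blocks are filtered out before any fetch
// request is built. `blockSizes` stands in for the (blockId, size) pairs
// that a reducer obtains from the map output statuses.
val blockSizes: Seq[(String, Long)] = Seq(("shuffle_0_0_0", 1024L), ("shuffle_0_1_0", 0L))
val toFetch = blockSizes.filter { case (_, size) => size > 0 } // empty blocks skipped
assert(toFetch.map(_._1) == Seq("shuffle_0_0_0"))
```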

Additionally, our decompression and deserialization streams treat zero-size inputs as empty streams rather than as errors: some layers treat EOF as "end-of-stream" (longstanding behavior dating back to the earliest versions of Spark), and decompressors like Snappy may tolerate zero-size inputs.
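To illustrate that "EOF means end-of-stream" pattern, here is a minimal, self-contained sketch (using a plain `DataInputStream` rather than Spark's actual deserialization stack): reading records until `EOFException` means a zero-size input silently deserializes to zero records.

```scala
import java.io.{ByteArrayInputStream, DataInputStream, EOFException}

// Read ints until EOF; treating EOFException as a clean end-of-stream means a
// zero-size input yields an empty result with no hint that anything was lost.
def readInts(bytes: Array[Byte]): Seq[Int] = {
  val in = new DataInputStream(new ByteArrayInputStream(bytes))
  val out = scala.collection.mutable.ArrayBuffer.empty[Int]
  try {
    while (true) out += in.readInt()
  } catch {
    case _: EOFException => // interpreted as "no more records", not an error
  }
  out.toSeq
}

readInts(Array.emptyByteArray) // => Seq(): silently empty
```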

As a result, if some other bug caused legitimate buffers to be replaced with zero-size buffers (due to corruption on either the send or receive side), the failure would manifest as silent data loss rather than an explicit fail-fast error.

This patch addresses this problem by adding a `buf.size != 0` check. See code comments for pointers to tests which guarantee the invariants relied on here.

## How was this patch tested?

Existing tests (which required modifications, since some were creating empty buffers in mocks). I also added a test to make sure we fail on zero-size blocks.

To test that the zero-size blocks are indeed a potential corruption source, I manually ran a workload in `spark-shell` with a modified build which replaces all buffers with zero-size buffers in the receive path.
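For reference, that manual check amounted to a fault-injection experiment: wrap every fetched buffer so its payload is dropped. A hedged sketch of the idea (illustrative only; `corruptToZeroSize` is not the actual modification used):

```scala
import java.nio.ByteBuffer
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}

// Replace a fetched buffer's payload with an empty one. Wiring this into the
// receive path previously produced silently-empty shuffle reads; with this
// patch it triggers an explicit FetchFailedException instead.
def corruptToZeroSize(buf: ManagedBuffer): ManagedBuffer =
  new NioManagedBuffer(ByteBuffer.allocate(0))
```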

Author: Josh Rosen <joshrosen@databricks.com>

Closes #21219 from JoshRosen/SPARK-24160.


@@ -414,6 +414,25 @@ final class ShuffleBlockFetcherIterator(
             logDebug("Number of requests in flight " + reqsInFlight)
           }
 
+          if (buf.size == 0) {
+            // We will never legitimately receive a zero-size block. All blocks with zero records
+            // have zero size and all zero-size blocks have no records (and hence should never
+            // have been requested in the first place). This statement relies on behaviors of the
+            // shuffle writers, which are guaranteed by the following test cases:
+            //
+            // - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
+            // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+            // - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
+            //
+            // There is not an explicit test for SortShuffleWriter but the underlying APIs that
+            // it uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter,
+            // which returns a zero-size from commitAndGet() in case no records were written
+            // since the last call).
+            val msg = s"Received a zero-size buffer for block $blockId from $address " +
+              s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
+            throwFetchFailedException(blockId, address, new IOException(msg))
+          }
+
           val in = try {
             buf.createInputStream()
           } catch {


@@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   }
 
   // Create a mock managed buffer for testing
-  def createMockManagedBuffer(): ManagedBuffer = {
+  def createMockManagedBuffer(size: Int = 1): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
     val in = mock(classOf[InputStream])
     when(in.read(any())).thenReturn(1)
     when(in.read(any(), any(), any())).thenReturn(1)
     when(mockManagedBuffer.createInputStream()).thenReturn(in)
+    when(mockManagedBuffer.size()).thenReturn(size)
     mockManagedBuffer
   }
@@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     intercept[FetchFailedException] { iterator.next() }
   }
 
+  private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = {
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.size()).thenReturn(size)
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    corruptBuffer
+  }
+
   test("retry corrupt blocks") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
 
-    val corruptStream = mock(classOf[InputStream])
-    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
-    val corruptBuffer = mock(classOf[ManagedBuffer])
-    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
-
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
 
     val transfer = mock(classOf[BlockTransferService])
@@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         listener.onBlockFetchSuccess(
           ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
         listener.onBlockFetchSuccess(
-          ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
         listener.onBlockFetchSuccess(
           ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
         sem.release()
@@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         Future {
           // Return the first block, and then fail.
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
           sem.release()
         }
       }
@@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   }
 
   test("big blocks are not checked for corruption") {
-    val corruptStream = mock(classOf[InputStream])
-    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
-    val corruptBuffer = mock(classOf[ManagedBuffer])
-    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
-    doReturn(10000L).when(corruptBuffer).size()
+    val corruptBuffer = mockCorruptBuffer(10000L)
 
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
 
-    val corruptStream = mock(classOf[InputStream])
-    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
-    val corruptBuffer = mock(classOf[ManagedBuffer])
-    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
-
     val transfer = mock(classOf[BlockTransferService])
     when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
@@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         listener.onBlockFetchSuccess(
           ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
         listener.onBlockFetchSuccess(
-          ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
         listener.onBlockFetchSuccess(
-          ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
+          ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
         sem.release()
       }
     }
@@ -527,4 +523,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // shuffle block to disk.
     assert(tempFileManager != null)
   }
+
+  test("fail zero-size blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()
+    )
+
+    val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)))
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true)
+
+    // All blocks fetched return zero length and should trigger a receive-side error:
+    val e = intercept[FetchFailedException] { iterator.next() }
+    assert(e.getMessage.contains("Received a zero-size buffer"))
+  }
 }