[SPARK-19520][STREAMING] Do not encrypt data written to the WAL.

Spark's I/O encryption uses an ephemeral key for each driver instance.
So driver B cannot decrypt data written by driver A since it doesn't
have the correct key.
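For context, a rough sketch of where that per-driver key comes from (CryptoStreamUtils is an internal private[spark] utility; this snippet is illustrative only and is not part of this change):

    import org.apache.spark.SparkConf
    import org.apache.spark.security.CryptoStreamUtils

    // Each driver generates its own random I/O encryption key at startup.
    val conf = new SparkConf().set("spark.io.encryption.enabled", "true")
    val keyA = CryptoStreamUtils.createKey(conf)  // key held by driver A
    val keyB = CryptoStreamUtils.createKey(conf)  // a restarted driver gets a fresh key
    // The two keys differ, so data encrypted under keyA is unreadable to the
    // driver holding keyB, which is exactly the recovery scenario for the WAL.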

The write ahead log is used for recovery, thus needs to be readable by
a different driver. So it cannot be encrypted by Spark's I/O encryption
code.

The BlockManager APIs used by the WAL code to write the data automatically
encrypt data, so changes are needed so that callers can opt out of
encryption.
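For illustration, the call pattern this change enables (a sketch only: SerializerManager and the new flags are private[spark] APIs, so code like this has to live inside Spark, for example in a test):

    import scala.reflect.ClassTag

    import org.apache.spark.serializer.SerializerManager
    import org.apache.spark.storage.StreamBlockId
    import org.apache.spark.util.io.ChunkedByteBuffer

    // Hypothetical helper: write a block destined for the WAL without I/O
    // encryption, then read it back the same way.
    def walRoundTrip[T: ClassTag](
        serializerManager: SerializerManager,
        blockId: StreamBlockId,
        values: Iterator[T]): Iterator[T] = {
      // Write side: skip encryption so any driver can replay the bytes.
      val bytes: ChunkedByteBuffer =
        serializerManager.dataSerialize(blockId, values, allowEncryption = false)
      // Read side: tell the deserializer the stream was never encrypted.
      serializerManager.dataDeserializeStream(
        blockId, bytes.toInputStream(), maybeEncrypted = false)(implicitly[ClassTag[T]])
    }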

Aside from that, the "putBytes" API in the BlockManager does not do
encryption, so a separate issue arose where the WAL would write
unencrypted data to the BM and decryption would then fail when those
blocks were read. So the WAL code needs to ask the BM to encrypt that data
when encryption is enabled; this code is not optimal since it results
in a (temporary) second copy of the data block in memory, but should be
OK for now until a more performant solution is added. The non-encryption
case should not be affected.
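A sketch of the resulting putBytes contract (again assuming code inside Spark, since BlockManager is private[spark]; the block id and payload below are placeholders):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkEnv
    import org.apache.spark.storage.{StorageLevel, StreamBlockId}
    import org.apache.spark.util.io.ChunkedByteBuffer

    // Placeholder for bytes read back from a WAL segment (always plaintext on disk).
    val dataRead: ByteBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
    val blockId = StreamBlockId(streamId = 1, uniqueId = 42L)

    // encrypt = true asks the BlockManager to encrypt the bytes on the way in
    // (only when spark.io.encryption.enabled is set), so that its read APIs,
    // which always decrypt, hand the data back correctly.
    val stored = SparkEnv.get.blockManager.putBytes(
      blockId,
      new ChunkedByteBuffer(dataRead.duplicate()),
      StorageLevel.MEMORY_AND_DISK_SER,
      tellMaster = true,
      encrypt = true)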

Tested with new unit tests, and by running streaming apps that do
recovery using the WAL data with I/O encryption turned on.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #16862 from vanzin/SPARK-19520.
Marcelo Vanzin 2017-02-13 14:19:41 -08:00
parent 9af8f743b0
commit 0169360ef5
8 changed files with 120 additions and 30 deletions

SecurityManager.scala

@@ -184,7 +184,7 @@ import org.apache.spark.util.Utils
 
 private[spark] class SecurityManager(
     sparkConf: SparkConf,
-    ioEncryptionKey: Option[Array[Byte]] = None)
+    val ioEncryptionKey: Option[Array[Byte]] = None)
   extends Logging with SecretKeyHolder {
 
   import SecurityManager._

SerializerManager.scala

@@ -171,20 +171,26 @@ private[spark] class SerializerManager(
   }
 
   /** Serializes into a chunked byte buffer. */
-  def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
-    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+  def dataSerialize[T: ClassTag](
+      blockId: BlockId,
+      values: Iterator[T],
+      allowEncryption: Boolean = true): ChunkedByteBuffer = {
+    dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]],
+      allowEncryption = allowEncryption)
   }
 
   /** Serializes into a chunked byte buffer. */
   def dataSerializeWithExplicitClassTag(
       blockId: BlockId,
       values: Iterator[_],
-      classTag: ClassTag[_]): ChunkedByteBuffer = {
+      classTag: ClassTag[_],
+      allowEncryption: Boolean = true): ChunkedByteBuffer = {
     val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
     val byteStream = new BufferedOutputStream(bbos)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
     val ser = getSerializer(classTag, autoPick).newInstance()
-    ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close()
+    val encrypted = if (allowEncryption) wrapForEncryption(byteStream) else byteStream
+    ser.serializeStream(wrapForCompression(blockId, encrypted)).writeAll(values).close()
     bbos.toChunkedByteBuffer
   }
 
@@ -194,13 +200,15 @@ private[spark] class SerializerManager(
    */
   def dataDeserializeStream[T](
       blockId: BlockId,
-      inputStream: InputStream)
+      inputStream: InputStream,
+      maybeEncrypted: Boolean = true)
       (classTag: ClassTag[T]): Iterator[T] = {
     val stream = new BufferedInputStream(inputStream)
     val autoPick = !blockId.isInstanceOf[StreamBlockId]
+    val decrypted = if (maybeEncrypted) wrapForEncryption(inputStream) else inputStream
     getSerializer(classTag, autoPick)
       .newInstance()
-      .deserializeStream(wrapStream(blockId, stream))
+      .deserializeStream(wrapForCompression(blockId, decrypted))
       .asIterator.asInstanceOf[Iterator[T]]
   }
 }

BlockManager.scala

@@ -28,6 +28,8 @@ import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
 
+import com.google.common.io.ByteStreams
+
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
 import org.apache.spark.internal.Logging
@@ -38,6 +40,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage.memory._
@@ -752,15 +755,43 @@ private[spark] class BlockManager(
   /**
    * Put a new block of serialized bytes to the block manager.
    *
+   * @param encrypt If true, asks the block manager to encrypt the data block before storing,
+   *                when I/O encryption is enabled. This is required for blocks that have been
+   *                read from unencrypted sources, since all the BlockManager read APIs
+   *                automatically do decryption.
    * @return true if the block was stored or false if an error occurred.
    */
   def putBytes[T: ClassTag](
       blockId: BlockId,
       bytes: ChunkedByteBuffer,
       level: StorageLevel,
-      tellMaster: Boolean = true): Boolean = {
+      tellMaster: Boolean = true,
+      encrypt: Boolean = false): Boolean = {
     require(bytes != null, "Bytes is null")
-    doPutBytes(blockId, bytes, level, implicitly[ClassTag[T]], tellMaster)
+
+    val bytesToStore =
+      if (encrypt && securityManager.ioEncryptionKey.isDefined) {
+        try {
+          val data = bytes.toByteBuffer
+          val in = new ByteBufferInputStream(data, true)
+          val byteBufOut = new ByteBufferOutputStream(data.remaining())
+          val out = CryptoStreamUtils.createCryptoOutputStream(byteBufOut, conf,
+            securityManager.ioEncryptionKey.get)
+          try {
+            ByteStreams.copy(in, out)
+          } finally {
+            in.close()
+            out.close()
+          }
+          new ChunkedByteBuffer(byteBufOut.toByteBuffer)
+        } finally {
+          bytes.dispose()
+        }
+      } else {
+        bytes
+      }
+
+    doPutBytes(blockId, bytesToStore, level, implicitly[ClassTag[T]], tellMaster)
   }
 
   /**

streaming-programming-guide.md

@@ -2017,6 +2017,9 @@ To run a Spark Streaming applications, you need to have the following.
   `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and
   `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See
   [Spark Streaming Configuration](configuration.html#spark-streaming) for more details.
+  Note that Spark will not encrypt data written to the write ahead log when I/O encryption is
+  enabled. If encryption of the write ahead log data is desired, it should be stored in a file
+  system that supports encryption natively.
 
 - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming
   application to process data as fast as it is being received, the receivers can be rate limited
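For illustration, a minimal driver-side setup matching the documentation note above (the app name and checkpoint path are placeholders); the WAL files written under the checkpoint directory stay unencrypted even though I/O encryption is on:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // I/O encryption covers shuffle and cached blocks, but not the WAL files,
    // so put the checkpoint directory on storage that the file system encrypts
    // if WAL confidentiality matters for the deployment.
    val conf = new SparkConf()
      .setAppName("wal-with-io-encryption")
      .set("spark.io.encryption.enabled", "true")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true")

    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("hdfs:///secure/checkpoints")  // placeholder path on an encrypted FS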

WriteAheadLogBackedBlockRDD.scala

@@ -27,7 +27,7 @@ import org.apache.spark._
 import org.apache.spark.rdd.BlockRDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.streaming.util._
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
 
 /**
@@ -158,13 +158,16 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
         logInfo(s"Read partition data of $this from write ahead log, record handle " +
           partition.walRecordHandle)
         if (storeInBlockManager) {
-          blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel)
+          blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel,
+            encrypt = true)
           logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
           dataRead.rewind()
         }
         serializerManager
           .dataDeserializeStream(
-            blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
+            blockId,
+            new ChunkedByteBuffer(dataRead).toInputStream(),
+            maybeEncrypted = false)(elementClassTag)
           .asInstanceOf[Iterator[T]]
       }

ReceivedBlockHandler.scala

@@ -87,7 +87,8 @@ private[streaming] class BlockManagerBasedBlockHandler(
         putResult
       case ByteBufferBlock(byteBuffer) =>
         blockManager.putBytes(
-          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
+          blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true,
+          encrypt = true)
       case o =>
         throw new SparkException(
           s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -175,10 +176,11 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
     val serializedBlock = block match {
       case ArrayBufferBlock(arrayBuffer) =>
         numRecords = Some(arrayBuffer.size.toLong)
-        serializerManager.dataSerialize(blockId, arrayBuffer.iterator)
+        serializerManager.dataSerialize(blockId, arrayBuffer.iterator, allowEncryption = false)
       case IteratorBlock(iterator) =>
         val countIterator = new CountingIterator(iterator)
-        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator)
+        val serializedBlock = serializerManager.dataSerialize(blockId, countIterator,
+          allowEncryption = false)
         numRecords = countIterator.count
         serializedBlock
       case ByteBufferBlock(byteBuffer) =>
@@ -193,7 +195,8 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
           blockId,
           serializedBlock,
           effectiveStorageLevel,
-          tellMaster = true)
+          tellMaster = true,
+          encrypt = true)
       if (!putSucceeded) {
         throw new SparkException(
           s"Could not store $blockId to block manager with storage level $storageLevel")

ReceivedBlockHandlerSuite.scala

@@ -32,10 +32,12 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.memory.StaticMemoryManager
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage._
@@ -44,7 +46,7 @@ import org.apache.spark.streaming.util._
 import org.apache.spark.util.{ManualClock, Utils}
 import org.apache.spark.util.io.ChunkedByteBuffer
 
-class ReceivedBlockHandlerSuite
+abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
   extends SparkFunSuite
   with BeforeAndAfter
   with Matchers
@@ -57,14 +59,22 @@ class ReceivedBlockHandlerSuite
   val conf = new SparkConf()
     .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
     .set("spark.app.id", "streaming-test")
+    .set(IO_ENCRYPTION_ENABLED, enableEncryption)
+  val encryptionKey =
+    if (enableEncryption) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+
   val hadoopConf = new Configuration()
   val streamId = 1
-  val securityMgr = new SecurityManager(conf)
+  val securityMgr = new SecurityManager(conf, encryptionKey)
   val broadcastManager = new BroadcastManager(true, conf, securityMgr)
   val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true)
   val shuffleManager = new SortShuffleManager(conf)
   val serializer = new KryoSerializer(conf)
-  var serializerManager = new SerializerManager(serializer, conf)
+  var serializerManager = new SerializerManager(serializer, conf, encryptionKey)
   val manualClock = new ManualClock
   val blockManagerSize = 10000000
   val blockManagerBuffer = new ArrayBuffer[BlockManager]()
@@ -164,7 +174,9 @@ class ReceivedBlockHandlerSuite
         val bytes = reader.read(fileSegment)
         reader.close()
         serializerManager.dataDeserializeStream(
-          generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
+          generateBlockId(),
+          new ChunkedByteBuffer(bytes).toInputStream(),
+          maybeEncrypted = false)(ClassTag.Any).toList
       }
       loggedData shouldEqual data
     }
@@ -208,6 +220,8 @@ class ReceivedBlockHandlerSuite
     sparkConf.set("spark.storage.unrollMemoryThreshold", "512")
     // spark.storage.unrollFraction set to 0.4 for BlockManager
     sparkConf.set("spark.storage.unrollFraction", "0.4")
+    sparkConf.set(IO_ENCRYPTION_ENABLED, enableEncryption)
+
     // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll
     blockManager = createBlockManager(12000, sparkConf)
@@ -343,7 +357,7 @@ class ReceivedBlockHandlerSuite
   }
 
   def dataToByteBuffer(b: Seq[String]) =
-    serializerManager.dataSerialize(generateBlockId, b.iterator)
+    serializerManager.dataSerialize(generateBlockId, b.iterator, allowEncryption = false)
 
   val blocks = data.grouped(10).toSeq
@@ -418,3 +432,6 @@ class ReceivedBlockHandlerSuite
 
   private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong)
 }
+
+class ReceivedBlockHandlerSuite extends BaseReceivedBlockHandlerSuite(false)
+class ReceivedBlockHandlerWithEncryptionSuite extends BaseReceivedBlockHandlerSuite(true)

WriteAheadLogBackedBlockRDDSuite.scala

@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 
 import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark.internal.config._
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
@@ -45,6 +46,7 @@ class WriteAheadLogBackedBlockRDDSuite
 
   override def beforeEach(): Unit = {
     super.beforeEach()
+    initSparkContext()
     dir = Utils.createTempDir()
   }
 
@@ -56,22 +58,33 @@ class WriteAheadLogBackedBlockRDDSuite
     }
   }
 
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-    sparkContext = new SparkContext(conf)
-    blockManager = sparkContext.env.blockManager
-    serializerManager = sparkContext.env.serializerManager
+  override def afterAll(): Unit = {
+    try {
+      stopSparkContext()
+    } finally {
+      super.afterAll()
+    }
   }
 
-  override def afterAll(): Unit = {
+  private def initSparkContext(_conf: Option[SparkConf] = None): Unit = {
+    if (sparkContext == null) {
+      sparkContext = new SparkContext(_conf.getOrElse(conf))
+      blockManager = sparkContext.env.blockManager
+      serializerManager = sparkContext.env.serializerManager
+    }
+  }
+
+  private def stopSparkContext(): Unit = {
     // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests.
     try {
-      sparkContext.stop()
+      if (sparkContext != null) {
+        sparkContext.stop()
+      }
       System.clearProperty("spark.driver.port")
       blockManager = null
      serializerManager = null
     } finally {
-      super.afterAll()
+      sparkContext = null
     }
   }
 
@@ -106,6 +119,17 @@ class WriteAheadLogBackedBlockRDDSuite
       numPartitions = 5, numPartitionsInBM = 0, numPartitionsInWAL = 5, testStoreInBM = true)
   }
 
+  test("read data in block manager and WAL with encryption on") {
+    stopSparkContext()
+    try {
+      val testConf = conf.clone().set(IO_ENCRYPTION_ENABLED, true)
+      initSparkContext(Some(testConf))
+      testRDD(numPartitions = 5, numPartitionsInBM = 3, numPartitionsInWAL = 2)
+    } finally {
+      stopSparkContext()
+    }
+  }
+
   /**
    * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager
    * and the rest to a write ahead log, and then reading it all back using the RDD.
@@ -226,7 +250,8 @@ class WriteAheadLogBackedBlockRDDSuite
     require(blockData.size === blockIds.size)
     val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
     val segments = blockData.zip(blockIds).map { case (data, id) =>
-      writer.write(serializerManager.dataSerialize(id, data.iterator).toByteBuffer)
+      writer.write(serializerManager.dataSerialize(id, data.iterator, allowEncryption = false)
+        .toByteBuffer)
     }
     writer.close()
     segments