[SPARK-6229] Add SASL encryption to network library.

There are two main parts of this change:

- Extending the bootstrap mechanism in the network library to add a server-side
  bootstrap (which works a little bit differently than the client-side bootstrap), and
  to allow the  bootstraps to modify the underlying channel.

- Use SASL to encrypt data going through the RPC channel.

The second item requires some non-optimal code to be able to work around the
fact that the outbound path in netty is not thread-safe, and ordering is very important
when encryption is in the picture.

A lot of the changes outside the network/common library are just to adjust to the
changed API for initializing the RPC server.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #5377 from vanzin/SPARK-6229 and squashes the following commits:

ff01966 [Marcelo Vanzin] Use fancy new size config style.
be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
7a2a805 [Marcelo Vanzin] Clean up some unneeded changes.
2f92237 [Marcelo Vanzin] Add comment.
67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code."
065f684 [Marcelo Vanzin] Add test to verify chunking.
3d1695d [Marcelo Vanzin] Minor cleanups.
73cff0e [Marcelo Vanzin] Skip bytes in decode path too.
318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code.
346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written.
a4a5938 [Marcelo Vanzin] Review feedback.
4797519 [Marcelo Vanzin] Remove unused import.
9908ada [Marcelo Vanzin] Fix test, SASL backend disposal.
7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled.
adb6f9d [Marcelo Vanzin] Review feedback.
cf2a605 [Marcelo Vanzin] Clean up some code.
8584323 [Marcelo Vanzin] Fix a comment.
e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server.
dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive.
b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes.
b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages.
39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption.
351a86f [Marcelo Vanzin] Add SASL encryption to network library.
fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it.
This commit is contained in:
Marcelo Vanzin 2015-05-01 19:01:46 -07:00 committed by Reynold Xin
parent 8f50a07d21
commit 38d4e9e446
27 changed files with 1070 additions and 106 deletions

View file

@ -150,8 +150,13 @@ import org.apache.spark.util.Utils
* authorization. If not filter is in place the user is generally null and no authorization
* can take place.
*
* Connection encryption (SSL) configuration is organized hierarchically. The user can configure
* the default SSL settings which will be used for all the supported communication protocols unless
* When authentication is being used, encryption can also be enabled by setting the option
* spark.authenticate.enableSaslEncryption to true. This is only supported by communication
* channels that use the network-common library, and can be used as an alternative to SSL in those
* cases.
*
* SSL can be used for encryption for certain communication channels. The user can configure the
* default SSL settings which will be used for all the supported communication protocols unless
* they are overwritten by protocol specific settings. This way the user can easily provide the
* common settings for all the protocols without disabling the ability to configure each one
* individually.
@ -412,6 +417,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
*/
def isAuthenticationEnabled(): Boolean = authOn
/**
* Checks whether SASL encryption should be enabled.
* @return Whether to enable SASL encryption when connecting to services that support it.
*/
def isSaslEncryptionEnabled(): Boolean = {
sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false)
}
/**
* Gets the user used for authenticating HTTP connections.
* For now use a single hardcoded user.

View file

@ -19,10 +19,12 @@ package org.apache.spark.deploy
import java.util.concurrent.CountDownLatch
import scala.collection.JavaConversions._
import org.apache.spark.{Logging, SparkConf, SecurityManager}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslRpcHandler
import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.util.Utils
@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
private val transportContext: TransportContext = {
val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
new TransportContext(transportConf, handler)
}
private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
private var server: TransportServer = _
@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
def start() {
require(server == null, "Shuffle server already started")
logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
server = transportContext.createServer(port)
val bootstraps =
if (useSasl) {
Seq(new SaslServerBootstrap(transportConf, securityManager))
} else {
Nil
}
server = transportContext.createServer(port, bootstraps)
}
def stop() {

View file

@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
private[this] var appId: String = _
override def init(blockDataManager: BlockDataManager): Unit = {
val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
if (!authEnabled) {
(nettyRpcHandler, None)
} else {
(new SaslRpcHandler(nettyRpcHandler, securityManager),
Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
}
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
var serverBootstrap: Option[TransportServerBootstrap] = None
var clientBootstrap: Option[TransportClientBootstrap] = None
if (authEnabled) {
serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
securityManager.isSaslEncryptionEnabled()))
}
transportContext = new TransportContext(transportConf, rpcHandler)
clientFactory = transportContext.createClientFactory(bootstrap.toList)
server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
clientFactory = transportContext.createClientFactory(clientBootstrap.toList)
server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0),
serverBootstrap.toList)
appId = conf.getAppId
logInfo("Server created on " + server.getPort)
}

View file

@ -656,7 +656,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@ -800,7 +800,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()

View file

@ -111,7 +111,8 @@ private[spark] class BlockManager(
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled())
} else {
blockTransferService
}

View file

@ -36,6 +36,7 @@ import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;
@ -82,13 +83,21 @@ public class TransportContext {
}
/** Create a server which will attempt to bind to a specific port. */
public TransportServer createServer(int port) {
return new TransportServer(this, port);
public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
return new TransportServer(this, port, rpcHandler, bootstraps);
}
/** Creates a new server, binding to any available ephemeral port. */
public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
return createServer(0, bootstraps);
}
public TransportServer createServer() {
return new TransportServer(this, 0);
return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
}
public TransportChannelHandler initializePipeline(SocketChannel channel) {
return initializePipeline(channel, rpcHandler);
}
/**
@ -96,13 +105,18 @@ public class TransportContext {
* has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
* response messages.
*
* @param channel The channel to initialize.
* @param channelRpcHandler The RPC handler to use for the channel.
*
* @return Returns the created TransportChannelHandler, which includes a TransportClient that can
* be used to communicate on this channel. The TransportClient is directly associated with a
* ChannelHandler to ensure all users of the same channel get the same TransportClient object.
*/
public TransportChannelHandler initializePipeline(SocketChannel channel) {
public TransportChannelHandler initializePipeline(
SocketChannel channel,
RpcHandler channelRpcHandler) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel);
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
@ -123,7 +137,7 @@ public class TransportContext {
* ResponseMessages. The channel is expected to have been successfully created, though certain
* properties (such as the remoteAddress()) may not be available yet.
*/
private TransportChannelHandler createChannelHandler(Channel channel) {
private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,

View file

@ -17,6 +17,8 @@
package org.apache.spark.network.client;
import io.netty.channel.Channel;
/**
* A bootstrap which is executed on a TransportClient before it is returned to the user.
* This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
@ -28,5 +30,5 @@ package org.apache.spark.network.client;
*/
public interface TransportClientBootstrap {
/** Performs the bootstrapping operation, throwing an exception on failure. */
public void doBootstrap(TransportClient client) throws RuntimeException;
void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
}

View file

@ -172,12 +172,14 @@ public class TransportClientFactory implements Closeable {
.option(ChannelOption.ALLOCATOR, pooledAllocator);
final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
clientRef.set(clientHandler.getClient());
channelRef.set(ch);
}
});
@ -192,6 +194,7 @@ public class TransportClientFactory implements Closeable {
}
TransportClient client = clientRef.get();
Channel channel = channelRef.get();
assert client != null : "Channel future completed successfully with null client";
// Execute any client bootstraps synchronously before marking the Client as successful.
@ -199,7 +202,7 @@ public class TransportClientFactory implements Closeable {
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
clientBootstrap.doBootstrap(client);
clientBootstrap.doBootstrap(client, channel);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;

View file

@ -17,8 +17,12 @@
package org.apache.spark.network.sasl;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -33,14 +37,24 @@ import org.apache.spark.network.util.TransportConf;
public class SaslClientBootstrap implements TransportClientBootstrap {
private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
this(conf, appId, secretKeyHolder, false);
}
public SaslClientBootstrap(
TransportConf conf,
String appId,
SecretKeyHolder secretKeyHolder,
boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
this.encrypt = encrypt;
}
/**
@ -49,8 +63,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
* due to mismatch.
*/
@Override
public void doBootstrap(TransportClient client) {
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
public void doBootstrap(TransportClient client, Channel channel) {
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
try {
byte[] payload = saslClient.firstToken();
@ -62,13 +76,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
payload = saslClient.response(response);
}
if (encrypt) {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
throw new RuntimeException(
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}
SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
saslClient = null;
logger.debug("Channel {} configured for SASL encryption.", client);
}
} finally {
try {
// Once authentication is complete, the server will trust all remaining communication.
saslClient.dispose();
} catch (RuntimeException e) {
logger.error("Error while disposing SASL client", e);
if (saslClient != null) {
try {
// Once authentication is complete, the server will trust all remaining communication.
saslClient.dispose();
} catch (RuntimeException e) {
logger.error("Error while disposing SASL client", e);
}
}
}
}
}

View file

@ -0,0 +1,291 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.sasl;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.FileRegion;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;
/**
* Provides SASL-based encription for transport channels. The single method exposed by this
* class installs the needed channel handlers on a connected channel.
*/
class SaslEncryption {
@VisibleForTesting
static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
/**
* Adds channel handlers that perform encryption / decryption of data using SASL.
*
* @param channel The channel.
* @param backend The SASL backend.
* @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
* memory usage.
*/
static void addToChannel(
Channel channel,
SaslEncryptionBackend backend,
int maxOutboundBlockSize) {
channel.pipeline()
.addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
.addFirst("saslDecryption", new DecryptionHandler(backend))
.addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
}
private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final int maxOutboundBlockSize;
private final SaslEncryptionBackend backend;
EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
this.backend = backend;
this.maxOutboundBlockSize = maxOutboundBlockSize;
}
/**
* Wrap the incoming message in an implementation that will perform encryption lazily. This is
* needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
* the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
* does not guarantee any ordering.
*/
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
try {
backend.dispose();
} finally {
super.handlerRemoved(ctx);
}
}
}
private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
private final SaslEncryptionBackend backend;
DecryptionHandler(SaslEncryptionBackend backend) {
this.backend = backend;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
throws Exception {
byte[] data;
int offset;
int length = msg.readableBytes();
if (msg.hasArray()) {
data = msg.array();
offset = msg.arrayOffset();
msg.skipBytes(length);
} else {
data = new byte[length];
msg.readBytes(data);
offset = 0;
}
out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
}
}
@VisibleForTesting
static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
private final SaslEncryptionBackend backend;
private final boolean isByteBuf;
private final ByteBuf buf;
private final FileRegion region;
/**
* A channel used to buffer input data for encryption. The channel has an upper size bound
* so that if the input is larger than the allowed buffer, it will be broken into multiple
* chunks.
*/
private final ByteArrayWritableChannel byteChannel;
private ByteBuf currentHeader;
private ByteBuffer currentChunk;
private long currentChunkSize;
private long currentReportedBytes;
private long unencryptedChunkSize;
private long transferred;
EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
"Unrecognized message type: %s", msg.getClass().getName());
this.backend = backend;
this.isByteBuf = msg instanceof ByteBuf;
this.buf = isByteBuf ? (ByteBuf) msg : null;
this.region = isByteBuf ? null : (FileRegion) msg;
this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
}
/**
* Returns the size of the original (unencrypted) message.
*
* This makes assumptions about how netty treats FileRegion instances, because there's no way
* to know beforehand what will be the size of the encrypted message. Namely, it assumes
* that netty will try to transfer data from this message while
* <code>transfered() < count()</code>. So these two methods return, technically, wrong data,
* but netty doesn't know better.
*/
@Override
public long count() {
return isByteBuf ? buf.readableBytes() : region.count();
}
@Override
public long position() {
return 0;
}
/**
* Returns an approximation of the amount of data transferred. See {@link #count()}.
*/
@Override
public long transfered() {
return transferred;
}
/**
* Transfers data from the original message to the channel, encrypting it in the process.
*
* This method also breaks down the original message into smaller chunks when needed. This
* is done to keep memory usage under control. This avoids having to copy the whole message
* data into memory at once, and can avoid ballooning memory usage when transferring large
* messages such as shuffle blocks.
*
* The {@link #transfered()} counter also behaves a little funny, in that it won't go forward
* until a whole chunk has been written. This is done because the code can't use the actual
* number of bytes written to the channel as the transferred count (see {@link #count()}).
* Instead, once an encrypted chunk is written to the output (including its header), the
* size of the original block will be added to the {@link #transfered()} amount.
*/
@Override
public long transferTo(final WritableByteChannel target, final long position)
throws IOException {
Preconditions.checkArgument(position == transfered(), "Invalid position.");
long reportedWritten = 0L;
long actuallyWritten = 0L;
do {
if (currentChunk == null) {
nextChunk();
}
if (currentHeader.readableBytes() > 0) {
int bytesWritten = target.write(currentHeader.nioBuffer());
currentHeader.skipBytes(bytesWritten);
actuallyWritten += bytesWritten;
if (currentHeader.readableBytes() > 0) {
// Break out of loop if there are still header bytes left to write.
break;
}
}
actuallyWritten += target.write(currentChunk);
if (!currentChunk.hasRemaining()) {
// Only update the count of written bytes once a full chunk has been written.
// See method javadoc.
long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
reportedWritten += chunkBytesRemaining;
transferred += chunkBytesRemaining;
currentHeader.release();
currentHeader = null;
currentChunk = null;
currentChunkSize = 0;
currentReportedBytes = 0;
}
} while (currentChunk == null && transfered() + reportedWritten < count());
// Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
// we return 1 until we can (i.e. until the reported count would actually match the size
// of the current chunk), at which point we resort to returning 0 so that the counts still
// match, at the cost of some performance. That situation should be rare, though.
if (reportedWritten != 0L) {
return reportedWritten;
}
if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
transferred += 1L;
currentReportedBytes += 1L;
return 1L;
}
return 0L;
}
private void nextChunk() throws IOException {
byteChannel.reset();
if (isByteBuf) {
int copied = byteChannel.write(buf.nioBuffer());
buf.skipBytes(copied);
} else {
region.transferTo(byteChannel, region.transfered());
}
byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
this.currentChunk = ByteBuffer.wrap(encrypted);
this.currentChunkSize = encrypted.length;
this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
this.unencryptedChunkSize = byteChannel.length();
}
@Override
protected void deallocate() {
if (currentHeader != null) {
currentHeader.release();
}
if (buf != null) {
buf.release();
}
if (region != null) {
region.release();
}
}
}
}

View file

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.sasl;
import javax.security.sasl.SaslException;
interface SaslEncryptionBackend {
/** Disposes of resources used by the backend. */
void dispose();
/** Encrypt data. */
byte[] wrap(byte[] data, int offset, int len) throws SaslException;
/** Decrypt data. */
byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
}

View file

@ -17,10 +17,10 @@
package org.apache.spark.network.sasl;
import java.util.concurrent.ConcurrentMap;
import javax.security.sasl.Sasl;
import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -28,6 +28,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.TransportConf;
/**
* RPC Handler which performs SASL authentication before delegating to a child RPC handler.
@ -37,8 +38,14 @@ import org.apache.spark.network.server.StreamManager;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
public class SaslRpcHandler extends RpcHandler {
private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
class SaslRpcHandler extends RpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
private final TransportConf conf;
/** The client channel. */
private final Channel channel;
/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;
@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler {
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
/** Maps each channel to its SASL authentication state. */
private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
private SparkSaslServer saslServer;
private boolean isComplete;
public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
this.conf = conf;
this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
this.channelAuthenticationMap = Maps.newConcurrentMap();
this.saslServer = null;
this.isComplete = false;
}
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
SparkSaslServer saslServer = channelAuthenticationMap.get(client);
if (saslServer != null && saslServer.isComplete()) {
if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
@ -68,15 +81,30 @@ public class SaslRpcHandler extends RpcHandler {
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
channelAuthenticationMap.put(client, saslServer);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
}
byte[] response = saslServer.response(saslMessage.payload);
callback.onSuccess(response);
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
// response. It's ok to change the channel pipeline here since we are processing an incoming
// message, so the pipeline is busy and no new incoming messages will be fed to it before this
// method returns. This assumes that the code ensures, through other means, that no outbound
// messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
logger.debug("SASL authentication successful for channel {}", client);
isComplete = true;
if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
saslServer = null;
} else {
saslServer.dispose();
saslServer = null;
}
}
callback.onSuccess(response);
}
@Override
@ -86,9 +114,9 @@ public class SaslRpcHandler extends RpcHandler {
@Override
public void connectionTerminated(TransportClient client) {
SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
if (saslServer != null) {
saslServer.dispose();
}
}
}

View file

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.sasl;
import io.netty.channel.Channel;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.TransportConf;
/**
* A bootstrap which is executed on a TransportServer's client channel once a client connects
* to the server. This allows customizing the client channel to allow for things such as SASL
* authentication.
*/
public class SaslServerBootstrap implements TransportServerBootstrap {
private final TransportConf conf;
private final SecretKeyHolder secretKeyHolder;
public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
}
/**
* Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
* negotiation.
*/
public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
}
}

View file

@ -17,6 +17,8 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
@ -27,9 +29,9 @@ import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.IOException;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -40,19 +42,25 @@ import static org.apache.spark.network.sasl.SparkSaslServer.*;
* initial state to the "authenticated" state. This client initializes the protocol via a
* firstToken, which is then followed by a set of challenges and responses.
*/
public class SparkSaslClient {
public class SparkSaslClient implements SaslEncryptionBackend {
private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
private final String expectedQop;
private SaslClient saslClient;
public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
Map<String, String> saslProps = ImmutableMap.<String, String>builder()
.put(Sasl.QOP, expectedQop)
.build();
try {
this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
SASL_PROPS, new ClientCallbackHandler());
saslProps, new ClientCallbackHandler());
} catch (SaslException e) {
throw Throwables.propagate(e);
}
@ -76,6 +84,11 @@ public class SparkSaslClient {
return saslClient != null && saslClient.isComplete();
}
/** Returns the value of a negotiated property. */
public Object getNegotiatedProperty(String name) {
return saslClient.getNegotiatedProperty(name);
}
/**
* Respond to server's SASL token.
* @param token contains server's SASL token
@ -93,6 +106,7 @@ public class SparkSaslClient {
* Disposes of any system resources or security-sensitive information the
* SaslClient might be using.
*/
@Override
public synchronized void dispose() {
if (saslClient != null) {
try {
@ -134,4 +148,15 @@ public class SparkSaslClient {
}
}
}
@Override
public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
return saslClient.wrap(data, offset, len);
}
@Override
public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
return saslClient.unwrap(data, offset, len);
}
}

View file

@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
* initial state to the "authenticated" state. (It is not a server in the sense of accepting
* connections on some socket.)
*/
public class SparkSaslServer {
public class SparkSaslServer implements SaslEncryptionBackend {
private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
/**
@ -60,26 +60,37 @@ public class SparkSaslServer {
static final String DIGEST = "DIGEST-MD5";
/**
* The quality of protection is just "auth". This means that we are doing
* authentication only, we are not supporting integrity or privacy protection of the
* communication channel after authentication. This could be changed to be configurable
* in the future.
* Quality of protection value that includes encryption.
*/
static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
.put(Sasl.QOP, "auth")
.put(Sasl.SERVER_AUTH, "true")
.build();
static final String QOP_AUTH_CONF = "auth-conf";
/**
* Quality of protection value that does not include encryption.
*/
static final String QOP_AUTH = "auth";
/** Identifier for a certain secret key within the secretKeyHolder. */
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
private SaslServer saslServer;
public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
public SparkSaslServer(
String secretKeyId,
SecretKeyHolder secretKeyHolder,
boolean alwaysEncrypt) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
// Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
// is listed first since it's preferred over the non-encrypted one (if the client also
// lists both in the request).
String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
Map<String, String> saslProps = ImmutableMap.<String, String>builder()
.put(Sasl.SERVER_AUTH, "true")
.put(Sasl.QOP, qop)
.build();
try {
this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
new DigestCallbackHandler());
} catch (SaslException e) {
throw Throwables.propagate(e);
@ -93,6 +104,11 @@ public class SparkSaslServer {
return saslServer != null && saslServer.isComplete();
}
/** Returns the value of a negotiated property. */
public Object getNegotiatedProperty(String name) {
return saslServer.getNegotiatedProperty(name);
}
/**
* Used to respond to server SASL tokens.
* @param token Server's SASL token
@ -110,6 +126,7 @@ public class SparkSaslServer {
* Disposes of any system resources or security-sensitive information the
* SaslServer might be using.
*/
@Override
public synchronized void dispose() {
if (saslServer != null) {
try {
@ -122,6 +139,16 @@ public class SparkSaslServer {
}
}
@Override
public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
return saslServer.wrap(data, offset, len);
}
@Override
public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
return saslServer.unwrap(data, offset, len);
}
/**
* Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
*/

View file

@ -19,8 +19,11 @@ package org.apache.spark.network.server;
import java.io.Closeable;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.TimeUnit;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelFuture;
@ -44,15 +47,23 @@ public class TransportServer implements Closeable {
private final TransportContext context;
private final TransportConf conf;
private final RpcHandler appRpcHandler;
private final List<TransportServerBootstrap> bootstraps;
private ServerBootstrap bootstrap;
private ChannelFuture channelFuture;
private int port = -1;
/** Creates a TransportServer that binds to the given port, or to any available if 0. */
public TransportServer(TransportContext context, int portToBind) {
public TransportServer(
TransportContext context,
int portToBind,
RpcHandler appRpcHandler,
List<TransportServerBootstrap> bootstraps) {
this.context = context;
this.conf = context.getConf();
this.appRpcHandler = appRpcHandler;
this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
init(portToBind);
}
@ -95,7 +106,11 @@ public class TransportServer implements Closeable {
bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
context.initializePipeline(ch);
RpcHandler rpcHandler = appRpcHandler;
for (TransportServerBootstrap bootstrap : bootstraps) {
rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
}
context.initializePipeline(ch, rpcHandler);
}
});

View file

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.server;
import io.netty.channel.Channel;
/**
* A bootstrap which is executed on a TransportServer's client channel once a client connects
* to the server. This allows customizing the client channel to allow for things such as SASL
* authentication.
*/
public interface TransportServerBootstrap {
/**
* Customizes the channel to include new features, if needed.
*
* @param channel The connected channel opened by the client.
* @param rpcHandler The RPC handler for the server.
* @return The RPC handler to use for the channel.
*/
RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
}

View file

@ -15,11 +15,14 @@
* limitations under the License.
*/
package org.apache.spark.network;
package org.apache.spark.network.util;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
/**
* A writable channel that stores the written data in a byte array in memory.
*/
public class ByteArrayWritableChannel implements WritableByteChannel {
private final byte[] data;
@ -27,19 +30,30 @@ public class ByteArrayWritableChannel implements WritableByteChannel {
public ByteArrayWritableChannel(int size) {
this.data = new byte[size];
this.offset = 0;
}
public byte[] getData() {
return data;
}
public int length() {
return offset;
}
/** Resets the channel so that writing to it will overwrite the existing buffer. */
public void reset() {
offset = 0;
}
/**
* Reads from the given buffer into the internal byte array.
*/
@Override
public int write(ByteBuffer src) {
int available = src.remaining();
src.get(data, offset, available);
offset += available;
return available;
int toTransfer = Math.min(src.remaining(), data.length - offset);
src.get(data, offset, toTransfer);
offset += toTransfer;
return toTransfer;
}
@Override

View file

@ -17,6 +17,8 @@
package org.apache.spark.network.util;
import com.google.common.primitives.Ints;
/**
* A central location that tracks all the settings we expose to users.
*/
@ -112,4 +114,20 @@ public class TransportConf {
public int portMaxRetries() {
return conf.getInt("spark.port.maxRetries", 16);
}
/**
* Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
*/
public int maxSaslEncryptedBlockSize() {
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
}
/**
* Whether the server should enforce encryption on SASL-authenticated connections.
*/
public boolean saslServerAlwaysEncrypt() {
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
}
}

View file

@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;
public class ProtocolSuite {

View file

@ -29,7 +29,7 @@ import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.network.ByteArrayWritableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;
public class MessageWithHeaderSuite {

View file

@ -17,12 +17,47 @@
package org.apache.spark.network.sasl;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static com.google.common.base.Charsets.UTF_8;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.sasl.SaslException;
import com.google.common.collect.Lists;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
/**
* Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
@ -44,8 +79,8 @@ public class SparkSaslSuite {
@Test
public void testMatching() {
SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
assertFalse(client.isComplete());
assertFalse(server.isComplete());
@ -64,11 +99,10 @@ public class SparkSaslSuite {
assertFalse(client.isComplete());
}
@Test
public void testNonMatching() {
SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
assertFalse(client.isComplete());
assertFalse(server.isComplete());
@ -86,4 +120,312 @@ public class SparkSaslSuite {
assertFalse(server.isComplete());
}
}
@Test
public void testSaslAuthentication() throws Exception {
testBasicSasl(false);
}
@Test
public void testSaslEncryption() throws Exception {
testBasicSasl(true);
}
private void testBasicSasl(boolean encrypt) throws Exception {
RpcHandler rpcHandler = mock(RpcHandler.class);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
byte[] message = (byte[]) invocation.getArguments()[1];
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
assertEquals("Ping", new String(message, UTF_8));
cb.onSuccess("Pong".getBytes(UTF_8));
return null;
}
})
.when(rpcHandler)
.receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
try {
byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
assertEquals("Pong", new String(response, UTF_8));
} finally {
ctx.close();
}
}
@Test
public void testEncryptedMessage() throws Exception {
SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
byte[] data = new byte[1024];
new Random().nextBytes(data);
when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
ByteBuf msg = Unpooled.buffer();
try {
msg.writeBytes(data);
// Create a channel with a really small buffer compared to the data. This means that on each
// call, the outbound data will not be fully written, so the write() method should return a
// dummy count to keep the channel alive when possible.
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
SaslEncryption.EncryptedMessage emsg =
new SaslEncryption.EncryptedMessage(backend, msg, 1024);
long count = emsg.transferTo(channel, emsg.transfered());
assertTrue(count < data.length);
assertTrue(count > 0);
// Here, the output buffer is full so nothing should be transferred.
assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
// Now there's room in the buffer, but not enough to transfer all the remaining data,
// so the dummy count should be returned.
channel.reset();
assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
// Eventually, the whole message should be transferred.
for (int i = 0; i < data.length / 32 - 2; i++) {
channel.reset();
assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
}
channel.reset();
count = emsg.transferTo(channel, emsg.transfered());
assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
assertEquals(data.length, emsg.transfered());
} finally {
msg.release();
}
}
@Test
public void testEncryptedMessageChunking() throws Exception {
File file = File.createTempFile("sasltest", ".txt");
try {
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
byte[] data = new byte[8 * 1024];
new Random().nextBytes(data);
Files.write(data, file);
SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
// It doesn't really matter what we return here, as long as it's not null.
when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
SaslEncryption.EncryptedMessage emsg =
new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
while (emsg.transfered() < emsg.count()) {
channel.reset();
emsg.transferTo(channel, emsg.transfered());
}
verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
} finally {
file.delete();
}
}
@Test
public void testFileRegionEncryption() throws Exception {
final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
System.setProperty(blockSizeConf, "1k");
final AtomicReference<ManagedBuffer> response = new AtomicReference();
final File file = File.createTempFile("sasltest", ".txt");
SaslTestCtx ctx = null;
try {
final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
StreamManager sm = mock(StreamManager.class);
when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
@Override
public ManagedBuffer answer(InvocationOnMock invocation) {
return new FileSegmentManagedBuffer(conf, file, 0, file.length());
}
});
RpcHandler rpcHandler = mock(RpcHandler.class);
when(rpcHandler.getStreamManager()).thenReturn(sm);
byte[] data = new byte[8 * 1024];
new Random().nextBytes(data);
Files.write(data, file);
ctx = new SaslTestCtx(rpcHandler, true, false);
final Object lock = new Object();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) {
response.set((ManagedBuffer) invocation.getArguments()[1]);
response.get().retain();
synchronized (lock) {
lock.notifyAll();
}
return null;
}
}).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
synchronized (lock) {
ctx.client.fetchChunk(0, 0, callback);
lock.wait(10 * 1000);
}
verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
assertTrue(Arrays.equals(data, received));
} finally {
file.delete();
if (ctx != null) {
ctx.close();
}
if (response.get() != null) {
response.get().release();
}
System.clearProperty(blockSizeConf);
}
}
@Test
public void testServerAlwaysEncrypt() throws Exception {
final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
System.setProperty(alwaysEncryptConfName, "true");
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
fail("Should have failed to connect without encryption.");
} catch (Exception e) {
assertTrue(e.getCause() instanceof SaslException);
} finally {
if (ctx != null) {
ctx.close();
}
System.clearProperty(alwaysEncryptConfName);
}
}
@Test
public void testDataEncryptionIsActuallyEnabled() throws Exception {
// This test sets up an encrypted connection but then, using a client bootstrap, removes
// the encryption handler from the client side. This should cause the server to not be
// able to understand RPCs sent to it and thus close the connection.
SaslTestCtx ctx = null;
try {
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
fail("Should have failed to send RPC to server.");
} catch (Exception e) {
assertFalse(e.getCause() instanceof TimeoutException);
} finally {
if (ctx != null) {
ctx.close();
}
}
}
private static class SaslTestCtx {
final TransportClient client;
final TransportServer server;
private final boolean encrypt;
private final boolean disableClientEncryption;
private final EncryptionCheckerBootstrap checker;
SaslTestCtx(
RpcHandler rpcHandler,
boolean encrypt,
boolean disableClientEncryption)
throws Exception {
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
TransportContext ctx = new TransportContext(conf, rpcHandler);
this.checker = new EncryptionCheckerBootstrap();
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
checker));
try {
List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
if (disableClientEncryption) {
clientBootstraps.add(new EncryptionDisablerBootstrap());
}
this.client = ctx.createClientFactory(clientBootstraps)
.createClient(TestUtils.getLocalHost(), server.getPort());
} catch (Exception e) {
close();
throw e;
}
this.encrypt = encrypt;
this.disableClientEncryption = disableClientEncryption;
}
void close() {
if (!disableClientEncryption) {
assertEquals(encrypt, checker.foundEncryptionHandler);
}
if (client != null) {
client.close();
}
if (server != null) {
server.close();
}
}
}
private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
implements TransportServerBootstrap {
boolean foundEncryptionHandler;
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
if (!foundEncryptionHandler) {
foundEncryptionHandler =
ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
}
ctx.write(msg, promise);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
super.handlerRemoved(ctx);
}
@Override
public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
channel.pipeline().addFirst("encryptionChecker", this);
return rpcHandler;
}
}
private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
@Override
public void doBootstrap(TransportClient client, Channel channel) {
channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
}
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.List;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient {
private final TransportConf conf;
private final boolean saslEnabled;
private final boolean saslEncryptionEnabled;
private final SecretKeyHolder secretKeyHolder;
private TransportClientFactory clientFactory;
@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient {
public ExternalShuffleClient(
TransportConf conf,
SecretKeyHolder secretKeyHolder,
boolean saslEnabled) {
boolean saslEnabled,
boolean saslEncryptionEnabled) {
Preconditions.checkArgument(
!saslEncryptionEnabled || saslEnabled,
"SASL encryption can only be enabled if SASL is also enabled.");
this.conf = conf;
this.secretKeyHolder = secretKeyHolder;
this.saslEnabled = saslEnabled;
this.saslEncryptionEnabled = saslEncryptionEnabled;
}
@Override
@ -70,7 +77,7 @@ public class ExternalShuffleClient extends ShuffleClient {
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
if (saslEnabled) {
bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder));
bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
}
clientFactory = context.createClientFactory(bootstraps);
}

View file

@ -18,6 +18,7 @@
package org.apache.spark.network.sasl;
import java.io.IOException;
import java.util.Arrays;
import com.google.common.collect.Lists;
import org.junit.After;
@ -37,6 +38,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@ -72,10 +74,11 @@ public class SaslIntegrationSuite {
@BeforeClass
public static void beforeAll() throws IOException {
SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
conf = new TransportConf(new SystemPropertyConfigProvider());
context = new TransportContext(conf, handler);
server = context.createServer();
context = new TransportContext(conf, new TestRpcHandler());
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
server = context.createServer(Arrays.asList(bootstrap));
}

View file

@ -136,7 +136,7 @@ public class ExternalShuffleIntegrationSuite {
final Semaphore requestsRemaining = new Semaphore(0);
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
client.init(APP_ID);
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
new BlockFetchingListener() {
@ -274,7 +274,7 @@ public class ExternalShuffleIntegrationSuite {
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
throws IOException {
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
client.init(APP_ID);
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
executorId, executorInfo);

View file

@ -18,6 +18,7 @@
package org.apache.spark.network.shuffle;
import java.io.IOException;
import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
@ -27,10 +28,11 @@ import static org.junit.Assert.*;
import org.apache.spark.network.TestUtils;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@ -42,10 +44,10 @@ public class ExternalShuffleSecuritySuite {
@Before
public void beforeEach() {
RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf),
new TestSecretKeyHolder("my-app-id", "secret"));
TransportContext context = new TransportContext(conf, handler);
this.server = context.createServer();
TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
new TestSecretKeyHolder("my-app-id", "secret"));
this.server = context.createServer(Arrays.asList(bootstrap));
}
@After
@ -58,13 +60,13 @@ public class ExternalShuffleSecuritySuite {
@Test
public void testValid() throws IOException {
validate("my-app-id", "secret");
validate("my-app-id", "secret", false);
}
@Test
public void testBadAppId() {
try {
validate("wrong-app-id", "secret");
validate("wrong-app-id", "secret", false);
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
}
@ -73,16 +75,21 @@ public class ExternalShuffleSecuritySuite {
@Test
public void testBadSecret() {
try {
validate("my-app-id", "bad-secret");
validate("my-app-id", "bad-secret", false);
} catch (Exception e) {
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
}
}
@Test
public void testEncryption() throws IOException {
validate("my-app-id", "secret", true);
}
/** Creates an ExternalShuffleClient and attempts to register with the server. */
private void validate(String appId, String secretKey) throws IOException {
private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
ExternalShuffleClient client =
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
client.init(appId);
// Registration either succeeds or throws an exception.
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",

View file

@ -17,9 +17,10 @@
package org.apache.spark.network.yarn;
import java.lang.Override;
import java.nio.ByteBuffer;
import java.util.List;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
@ -32,10 +33,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.ShuffleSecretManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
import org.apache.spark.network.util.TransportConf;
import org.apache.spark.network.yarn.util.HadoopConfigProvider;
@ -103,16 +105,17 @@ public class YarnShuffleService extends AuxiliaryService {
// special RPC handler that filters out unauthenticated fetch requests
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
blockHandler = new ExternalShuffleBlockHandler(transportConf);
RpcHandler rpcHandler = blockHandler;
List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
if (authEnabled) {
secretManager = new ShuffleSecretManager();
rpcHandler = new SaslRpcHandler(rpcHandler, secretManager);
bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
}
int port = conf.getInt(
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
TransportContext transportContext = new TransportContext(transportConf, rpcHandler);
shuffleServer = transportContext.createServer(port);
TransportContext transportContext = new TransportContext(transportConf, blockHandler);
shuffleServer = transportContext.createServer(port, bootstraps);
String authEnabledString = authEnabled ? "enabled" : "not enabled";
logger.info("Started YARN shuffle service for Spark on port {}. " +
"Authentication is {}.", port, authEnabledString);