[SPARK-6229] Add SASL encryption to network library.
There are two main parts of this change: - Extending the bootstrap mechanism in the network library to add a server-side bootstrap (which works a little bit differently than the client-side bootstrap), and to allow the bootstraps to modify the underlying channel. - Use SASL to encrypt data going through the RPC channel. The second item requires some non-optimal code to be able to work around the fact that the outbound path in netty is not thread-safe, and ordering is very important when encryption is in the picture. A lot of the changes outside the network/common library are just to adjust to the changed API for initializing the RPC server. Author: Marcelo Vanzin <vanzin@cloudera.com> Closes #5377 from vanzin/SPARK-6229 and squashes the following commits: ff01966 [Marcelo Vanzin] Use fancy new size config style. be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229 47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229 7a2a805 [Marcelo Vanzin] Clean up some unneeded changes. 2f92237 [Marcelo Vanzin] Add comment. 67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code." 065f684 [Marcelo Vanzin] Add test to verify chunking. 3d1695d [Marcelo Vanzin] Minor cleanups. 73cff0e [Marcelo Vanzin] Skip bytes in decode path too. 318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code. 346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written. a4a5938 [Marcelo Vanzin] Review feedback. 4797519 [Marcelo Vanzin] Remove unused import. 9908ada [Marcelo Vanzin] Fix test, SASL backend disposal. 7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled. adb6f9d [Marcelo Vanzin] Review feedback. cf2a605 [Marcelo Vanzin] Clean up some code. 8584323 [Marcelo Vanzin] Fix a comment. e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server. dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive. b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes. b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages. 39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption. 351a86f [Marcelo Vanzin] Add SASL encryption to network library. fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it.
This commit is contained in:
parent
8f50a07d21
commit
38d4e9e446
|
@ -150,8 +150,13 @@ import org.apache.spark.util.Utils
|
|||
* authorization. If not filter is in place the user is generally null and no authorization
|
||||
* can take place.
|
||||
*
|
||||
* Connection encryption (SSL) configuration is organized hierarchically. The user can configure
|
||||
* the default SSL settings which will be used for all the supported communication protocols unless
|
||||
* When authentication is being used, encryption can also be enabled by setting the option
|
||||
* spark.authenticate.enableSaslEncryption to true. This is only supported by communication
|
||||
* channels that use the network-common library, and can be used as an alternative to SSL in those
|
||||
* cases.
|
||||
*
|
||||
* SSL can be used for encryption for certain communication channels. The user can configure the
|
||||
* default SSL settings which will be used for all the supported communication protocols unless
|
||||
* they are overwritten by protocol specific settings. This way the user can easily provide the
|
||||
* common settings for all the protocols without disabling the ability to configure each one
|
||||
* individually.
|
||||
|
@ -412,6 +417,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
|
|||
*/
|
||||
def isAuthenticationEnabled(): Boolean = authOn
|
||||
|
||||
/**
|
||||
* Checks whether SASL encryption should be enabled.
|
||||
* @return Whether to enable SASL encryption when connecting to services that support it.
|
||||
*/
|
||||
def isSaslEncryptionEnabled(): Boolean = {
|
||||
sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the user used for authenticating HTTP connections.
|
||||
* For now use a single hardcoded user.
|
||||
|
|
|
@ -19,10 +19,12 @@ package org.apache.spark.deploy
|
|||
|
||||
import java.util.concurrent.CountDownLatch
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import org.apache.spark.{Logging, SparkConf, SecurityManager}
|
||||
import org.apache.spark.network.TransportContext
|
||||
import org.apache.spark.network.netty.SparkTransportConf
|
||||
import org.apache.spark.network.sasl.SaslRpcHandler
|
||||
import org.apache.spark.network.sasl.SaslServerBootstrap
|
||||
import org.apache.spark.network.server.TransportServer
|
||||
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
|
|||
|
||||
private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
|
||||
private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
|
||||
private val transportContext: TransportContext = {
|
||||
val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
|
||||
new TransportContext(transportConf, handler)
|
||||
}
|
||||
private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
|
||||
|
||||
private var server: TransportServer = _
|
||||
|
||||
|
@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
|
|||
def start() {
|
||||
require(server == null, "Shuffle server already started")
|
||||
logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
|
||||
server = transportContext.createServer(port)
|
||||
val bootstraps =
|
||||
if (useSasl) {
|
||||
Seq(new SaslServerBootstrap(transportConf, securityManager))
|
||||
} else {
|
||||
Nil
|
||||
}
|
||||
server = transportContext.createServer(port, bootstraps)
|
||||
}
|
||||
|
||||
def stop() {
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
|
|||
import org.apache.spark.network._
|
||||
import org.apache.spark.network.buffer.ManagedBuffer
|
||||
import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
|
||||
import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
|
||||
import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
|
||||
import org.apache.spark.network.server._
|
||||
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
|
||||
import org.apache.spark.network.shuffle.protocol.UploadBlock
|
||||
|
@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
|
|||
private[this] var appId: String = _
|
||||
|
||||
override def init(blockDataManager: BlockDataManager): Unit = {
|
||||
val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
|
||||
val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
|
||||
if (!authEnabled) {
|
||||
(nettyRpcHandler, None)
|
||||
} else {
|
||||
(new SaslRpcHandler(nettyRpcHandler, securityManager),
|
||||
Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
|
||||
}
|
||||
val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
|
||||
var serverBootstrap: Option[TransportServerBootstrap] = None
|
||||
var clientBootstrap: Option[TransportClientBootstrap] = None
|
||||
if (authEnabled) {
|
||||
serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
|
||||
clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
|
||||
securityManager.isSaslEncryptionEnabled()))
|
||||
}
|
||||
transportContext = new TransportContext(transportConf, rpcHandler)
|
||||
clientFactory = transportContext.createClientFactory(bootstrap.toList)
|
||||
server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
|
||||
clientFactory = transportContext.createClientFactory(clientBootstrap.toList)
|
||||
server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0),
|
||||
serverBootstrap.toList)
|
||||
appId = conf.getAppId
|
||||
logInfo("Server created on " + server.getPort)
|
||||
}
|
||||
|
|
|
@ -656,7 +656,7 @@ private[nio] class ConnectionManager(
|
|||
connection.synchronized {
|
||||
if (connection.sparkSaslServer == null) {
|
||||
logDebug("Creating sasl Server")
|
||||
connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
|
||||
connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false)
|
||||
}
|
||||
}
|
||||
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
|
||||
|
@ -800,7 +800,7 @@ private[nio] class ConnectionManager(
|
|||
if (!conn.isSaslComplete()) {
|
||||
conn.synchronized {
|
||||
if (conn.sparkSaslClient == null) {
|
||||
conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
|
||||
conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false)
|
||||
var firstResponse: Array[Byte] = null
|
||||
try {
|
||||
firstResponse = conn.sparkSaslClient.firstToken()
|
||||
|
|
|
@ -111,7 +111,8 @@ private[spark] class BlockManager(
|
|||
// standard BlockTransferService to directly connect to other Executors.
|
||||
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
|
||||
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
|
||||
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
|
||||
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
|
||||
securityManager.isSaslEncryptionEnabled())
|
||||
} else {
|
||||
blockTransferService
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.apache.spark.network.server.RpcHandler;
|
|||
import org.apache.spark.network.server.TransportChannelHandler;
|
||||
import org.apache.spark.network.server.TransportRequestHandler;
|
||||
import org.apache.spark.network.server.TransportServer;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.util.NettyUtils;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
|
@ -82,13 +83,21 @@ public class TransportContext {
|
|||
}
|
||||
|
||||
/** Create a server which will attempt to bind to a specific port. */
|
||||
public TransportServer createServer(int port) {
|
||||
return new TransportServer(this, port);
|
||||
public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
|
||||
return new TransportServer(this, port, rpcHandler, bootstraps);
|
||||
}
|
||||
|
||||
/** Creates a new server, binding to any available ephemeral port. */
|
||||
public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
|
||||
return createServer(0, bootstraps);
|
||||
}
|
||||
|
||||
public TransportServer createServer() {
|
||||
return new TransportServer(this, 0);
|
||||
return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
|
||||
}
|
||||
|
||||
public TransportChannelHandler initializePipeline(SocketChannel channel) {
|
||||
return initializePipeline(channel, rpcHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -96,13 +105,18 @@ public class TransportContext {
|
|||
* has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
|
||||
* response messages.
|
||||
*
|
||||
* @param channel The channel to initialize.
|
||||
* @param channelRpcHandler The RPC handler to use for the channel.
|
||||
*
|
||||
* @return Returns the created TransportChannelHandler, which includes a TransportClient that can
|
||||
* be used to communicate on this channel. The TransportClient is directly associated with a
|
||||
* ChannelHandler to ensure all users of the same channel get the same TransportClient object.
|
||||
*/
|
||||
public TransportChannelHandler initializePipeline(SocketChannel channel) {
|
||||
public TransportChannelHandler initializePipeline(
|
||||
SocketChannel channel,
|
||||
RpcHandler channelRpcHandler) {
|
||||
try {
|
||||
TransportChannelHandler channelHandler = createChannelHandler(channel);
|
||||
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
|
||||
channel.pipeline()
|
||||
.addLast("encoder", encoder)
|
||||
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
|
||||
|
@ -123,7 +137,7 @@ public class TransportContext {
|
|||
* ResponseMessages. The channel is expected to have been successfully created, though certain
|
||||
* properties (such as the remoteAddress()) may not be available yet.
|
||||
*/
|
||||
private TransportChannelHandler createChannelHandler(Channel channel) {
|
||||
private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
|
||||
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
|
||||
TransportClient client = new TransportClient(channel, responseHandler);
|
||||
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.network.client;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
|
||||
/**
|
||||
* A bootstrap which is executed on a TransportClient before it is returned to the user.
|
||||
* This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
|
||||
|
@ -28,5 +30,5 @@ package org.apache.spark.network.client;
|
|||
*/
|
||||
public interface TransportClientBootstrap {
|
||||
/** Performs the bootstrapping operation, throwing an exception on failure. */
|
||||
public void doBootstrap(TransportClient client) throws RuntimeException;
|
||||
void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
|
||||
}
|
||||
|
|
|
@ -172,12 +172,14 @@ public class TransportClientFactory implements Closeable {
|
|||
.option(ChannelOption.ALLOCATOR, pooledAllocator);
|
||||
|
||||
final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
|
||||
final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();
|
||||
|
||||
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel ch) {
|
||||
TransportChannelHandler clientHandler = context.initializePipeline(ch);
|
||||
clientRef.set(clientHandler.getClient());
|
||||
channelRef.set(ch);
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -192,6 +194,7 @@ public class TransportClientFactory implements Closeable {
|
|||
}
|
||||
|
||||
TransportClient client = clientRef.get();
|
||||
Channel channel = channelRef.get();
|
||||
assert client != null : "Channel future completed successfully with null client";
|
||||
|
||||
// Execute any client bootstraps synchronously before marking the Client as successful.
|
||||
|
@ -199,7 +202,7 @@ public class TransportClientFactory implements Closeable {
|
|||
logger.debug("Connection to {} successful, running bootstraps...", address);
|
||||
try {
|
||||
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
|
||||
clientBootstrap.doBootstrap(client);
|
||||
clientBootstrap.doBootstrap(client, channel);
|
||||
}
|
||||
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
|
||||
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
|
||||
|
|
|
@ -17,8 +17,12 @@
|
|||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import javax.security.sasl.Sasl;
|
||||
import javax.security.sasl.SaslException;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -33,14 +37,24 @@ import org.apache.spark.network.util.TransportConf;
|
|||
public class SaslClientBootstrap implements TransportClientBootstrap {
|
||||
private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);
|
||||
|
||||
private final boolean encrypt;
|
||||
private final TransportConf conf;
|
||||
private final String appId;
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
|
||||
this(conf, appId, secretKeyHolder, false);
|
||||
}
|
||||
|
||||
public SaslClientBootstrap(
|
||||
TransportConf conf,
|
||||
String appId,
|
||||
SecretKeyHolder secretKeyHolder,
|
||||
boolean encrypt) {
|
||||
this.conf = conf;
|
||||
this.appId = appId;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
this.encrypt = encrypt;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -49,8 +63,8 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
|
|||
* due to mismatch.
|
||||
*/
|
||||
@Override
|
||||
public void doBootstrap(TransportClient client) {
|
||||
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
|
||||
public void doBootstrap(TransportClient client, Channel channel) {
|
||||
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
|
||||
try {
|
||||
byte[] payload = saslClient.firstToken();
|
||||
|
||||
|
@ -62,13 +76,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap {
|
|||
byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
|
||||
payload = saslClient.response(response);
|
||||
}
|
||||
|
||||
if (encrypt) {
|
||||
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
|
||||
throw new RuntimeException(
|
||||
new SaslException("Encryption requests by negotiated non-encrypted connection."));
|
||||
}
|
||||
SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
|
||||
saslClient = null;
|
||||
logger.debug("Channel {} configured for SASL encryption.", client);
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
// Once authentication is complete, the server will trust all remaining communication.
|
||||
saslClient.dispose();
|
||||
} catch (RuntimeException e) {
|
||||
logger.error("Error while disposing SASL client", e);
|
||||
if (saslClient != null) {
|
||||
try {
|
||||
// Once authentication is complete, the server will trust all remaining communication.
|
||||
saslClient.dispose();
|
||||
} catch (RuntimeException e) {
|
||||
logger.error("Error while disposing SASL client", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,291 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.WritableByteChannel;
|
||||
import java.util.List;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import io.netty.channel.FileRegion;
|
||||
import io.netty.handler.codec.MessageToMessageDecoder;
|
||||
import io.netty.util.AbstractReferenceCounted;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||
import org.apache.spark.network.util.NettyUtils;
|
||||
|
||||
/**
|
||||
* Provides SASL-based encription for transport channels. The single method exposed by this
|
||||
* class installs the needed channel handlers on a connected channel.
|
||||
*/
|
||||
class SaslEncryption {
|
||||
|
||||
@VisibleForTesting
|
||||
static final String ENCRYPTION_HANDLER_NAME = "saslEncryption";
|
||||
|
||||
/**
|
||||
* Adds channel handlers that perform encryption / decryption of data using SASL.
|
||||
*
|
||||
* @param channel The channel.
|
||||
* @param backend The SASL backend.
|
||||
* @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control
|
||||
* memory usage.
|
||||
*/
|
||||
static void addToChannel(
|
||||
Channel channel,
|
||||
SaslEncryptionBackend backend,
|
||||
int maxOutboundBlockSize) {
|
||||
channel.pipeline()
|
||||
.addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize))
|
||||
.addFirst("saslDecryption", new DecryptionHandler(backend))
|
||||
.addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder());
|
||||
}
|
||||
|
||||
private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
|
||||
|
||||
private final int maxOutboundBlockSize;
|
||||
private final SaslEncryptionBackend backend;
|
||||
|
||||
EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) {
|
||||
this.backend = backend;
|
||||
this.maxOutboundBlockSize = maxOutboundBlockSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap the incoming message in an implementation that will perform encryption lazily. This is
|
||||
* needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in
|
||||
* the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it
|
||||
* does not guarantee any ordering.
|
||||
*/
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
|
||||
throws Exception {
|
||||
|
||||
ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
|
||||
try {
|
||||
backend.dispose();
|
||||
} finally {
|
||||
super.handlerRemoved(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class DecryptionHandler extends MessageToMessageDecoder<ByteBuf> {
|
||||
|
||||
private final SaslEncryptionBackend backend;
|
||||
|
||||
DecryptionHandler(SaslEncryptionBackend backend) {
|
||||
this.backend = backend;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
|
||||
throws Exception {
|
||||
|
||||
byte[] data;
|
||||
int offset;
|
||||
int length = msg.readableBytes();
|
||||
if (msg.hasArray()) {
|
||||
data = msg.array();
|
||||
offset = msg.arrayOffset();
|
||||
msg.skipBytes(length);
|
||||
} else {
|
||||
data = new byte[length];
|
||||
msg.readBytes(data);
|
||||
offset = 0;
|
||||
}
|
||||
|
||||
out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
|
||||
|
||||
private final SaslEncryptionBackend backend;
|
||||
private final boolean isByteBuf;
|
||||
private final ByteBuf buf;
|
||||
private final FileRegion region;
|
||||
|
||||
/**
|
||||
* A channel used to buffer input data for encryption. The channel has an upper size bound
|
||||
* so that if the input is larger than the allowed buffer, it will be broken into multiple
|
||||
* chunks.
|
||||
*/
|
||||
private final ByteArrayWritableChannel byteChannel;
|
||||
|
||||
private ByteBuf currentHeader;
|
||||
private ByteBuffer currentChunk;
|
||||
private long currentChunkSize;
|
||||
private long currentReportedBytes;
|
||||
private long unencryptedChunkSize;
|
||||
private long transferred;
|
||||
|
||||
EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) {
|
||||
Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion,
|
||||
"Unrecognized message type: %s", msg.getClass().getName());
|
||||
this.backend = backend;
|
||||
this.isByteBuf = msg instanceof ByteBuf;
|
||||
this.buf = isByteBuf ? (ByteBuf) msg : null;
|
||||
this.region = isByteBuf ? null : (FileRegion) msg;
|
||||
this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the size of the original (unencrypted) message.
|
||||
*
|
||||
* This makes assumptions about how netty treats FileRegion instances, because there's no way
|
||||
* to know beforehand what will be the size of the encrypted message. Namely, it assumes
|
||||
* that netty will try to transfer data from this message while
|
||||
* <code>transfered() < count()</code>. So these two methods return, technically, wrong data,
|
||||
* but netty doesn't know better.
|
||||
*/
|
||||
@Override
|
||||
public long count() {
|
||||
return isByteBuf ? buf.readableBytes() : region.count();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long position() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an approximation of the amount of data transferred. See {@link #count()}.
|
||||
*/
|
||||
@Override
|
||||
public long transfered() {
|
||||
return transferred;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transfers data from the original message to the channel, encrypting it in the process.
|
||||
*
|
||||
* This method also breaks down the original message into smaller chunks when needed. This
|
||||
* is done to keep memory usage under control. This avoids having to copy the whole message
|
||||
* data into memory at once, and can avoid ballooning memory usage when transferring large
|
||||
* messages such as shuffle blocks.
|
||||
*
|
||||
* The {@link #transfered()} counter also behaves a little funny, in that it won't go forward
|
||||
* until a whole chunk has been written. This is done because the code can't use the actual
|
||||
* number of bytes written to the channel as the transferred count (see {@link #count()}).
|
||||
* Instead, once an encrypted chunk is written to the output (including its header), the
|
||||
* size of the original block will be added to the {@link #transfered()} amount.
|
||||
*/
|
||||
@Override
|
||||
public long transferTo(final WritableByteChannel target, final long position)
|
||||
throws IOException {
|
||||
|
||||
Preconditions.checkArgument(position == transfered(), "Invalid position.");
|
||||
|
||||
long reportedWritten = 0L;
|
||||
long actuallyWritten = 0L;
|
||||
do {
|
||||
if (currentChunk == null) {
|
||||
nextChunk();
|
||||
}
|
||||
|
||||
if (currentHeader.readableBytes() > 0) {
|
||||
int bytesWritten = target.write(currentHeader.nioBuffer());
|
||||
currentHeader.skipBytes(bytesWritten);
|
||||
actuallyWritten += bytesWritten;
|
||||
if (currentHeader.readableBytes() > 0) {
|
||||
// Break out of loop if there are still header bytes left to write.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
actuallyWritten += target.write(currentChunk);
|
||||
if (!currentChunk.hasRemaining()) {
|
||||
// Only update the count of written bytes once a full chunk has been written.
|
||||
// See method javadoc.
|
||||
long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes;
|
||||
reportedWritten += chunkBytesRemaining;
|
||||
transferred += chunkBytesRemaining;
|
||||
currentHeader.release();
|
||||
currentHeader = null;
|
||||
currentChunk = null;
|
||||
currentChunkSize = 0;
|
||||
currentReportedBytes = 0;
|
||||
}
|
||||
} while (currentChunk == null && transfered() + reportedWritten < count());
|
||||
|
||||
// Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead,
|
||||
// we return 1 until we can (i.e. until the reported count would actually match the size
|
||||
// of the current chunk), at which point we resort to returning 0 so that the counts still
|
||||
// match, at the cost of some performance. That situation should be rare, though.
|
||||
if (reportedWritten != 0L) {
|
||||
return reportedWritten;
|
||||
}
|
||||
|
||||
if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) {
|
||||
transferred += 1L;
|
||||
currentReportedBytes += 1L;
|
||||
return 1L;
|
||||
}
|
||||
|
||||
return 0L;
|
||||
}
|
||||
|
||||
private void nextChunk() throws IOException {
|
||||
byteChannel.reset();
|
||||
if (isByteBuf) {
|
||||
int copied = byteChannel.write(buf.nioBuffer());
|
||||
buf.skipBytes(copied);
|
||||
} else {
|
||||
region.transferTo(byteChannel, region.transfered());
|
||||
}
|
||||
|
||||
byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length());
|
||||
this.currentChunk = ByteBuffer.wrap(encrypted);
|
||||
this.currentChunkSize = encrypted.length;
|
||||
this.currentHeader = Unpooled.copyLong(8 + currentChunkSize);
|
||||
this.unencryptedChunkSize = byteChannel.length();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void deallocate() {
|
||||
if (currentHeader != null) {
|
||||
currentHeader.release();
|
||||
}
|
||||
if (buf != null) {
|
||||
buf.release();
|
||||
}
|
||||
if (region != null) {
|
||||
region.release();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import javax.security.sasl.SaslException;
|
||||
|
||||
interface SaslEncryptionBackend {
|
||||
|
||||
/** Disposes of resources used by the backend. */
|
||||
void dispose();
|
||||
|
||||
/** Encrypt data. */
|
||||
byte[] wrap(byte[] data, int offset, int len) throws SaslException;
|
||||
|
||||
/** Decrypt data. */
|
||||
byte[] unwrap(byte[] data, int offset, int len) throws SaslException;
|
||||
|
||||
}
|
|
@ -17,10 +17,10 @@
|
|||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import java.util.concurrent.ConcurrentMap;
|
||||
import javax.security.sasl.Sasl;
|
||||
|
||||
import com.google.common.collect.Maps;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -28,6 +28,7 @@ import org.apache.spark.network.client.RpcResponseCallback;
|
|||
import org.apache.spark.network.client.TransportClient;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.StreamManager;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
/**
|
||||
* RPC Handler which performs SASL authentication before delegating to a child RPC handler.
|
||||
|
@ -37,8 +38,14 @@ import org.apache.spark.network.server.StreamManager;
|
|||
* Note that the authentication process consists of multiple challenge-response pairs, each of
|
||||
* which are individual RPCs.
|
||||
*/
|
||||
public class SaslRpcHandler extends RpcHandler {
|
||||
private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
|
||||
class SaslRpcHandler extends RpcHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
|
||||
|
||||
/** Transport configuration. */
|
||||
private final TransportConf conf;
|
||||
|
||||
/** The client channel. */
|
||||
private final Channel channel;
|
||||
|
||||
/** RpcHandler we will delegate to for authenticated connections. */
|
||||
private final RpcHandler delegate;
|
||||
|
@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
/** Maps each channel to its SASL authentication state. */
|
||||
private final ConcurrentMap<TransportClient, SparkSaslServer> channelAuthenticationMap;
|
||||
private SparkSaslServer saslServer;
|
||||
private boolean isComplete;
|
||||
|
||||
public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) {
|
||||
SaslRpcHandler(
|
||||
TransportConf conf,
|
||||
Channel channel,
|
||||
RpcHandler delegate,
|
||||
SecretKeyHolder secretKeyHolder) {
|
||||
this.conf = conf;
|
||||
this.channel = channel;
|
||||
this.delegate = delegate;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
this.channelAuthenticationMap = Maps.newConcurrentMap();
|
||||
this.saslServer = null;
|
||||
this.isComplete = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
|
||||
SparkSaslServer saslServer = channelAuthenticationMap.get(client);
|
||||
if (saslServer != null && saslServer.isComplete()) {
|
||||
if (isComplete) {
|
||||
// Authentication complete, delegate to base handler.
|
||||
delegate.receive(client, message, callback);
|
||||
return;
|
||||
|
@ -68,15 +81,30 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
|
||||
if (saslServer == null) {
|
||||
// First message in the handshake, setup the necessary state.
|
||||
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder);
|
||||
channelAuthenticationMap.put(client, saslServer);
|
||||
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
|
||||
conf.saslServerAlwaysEncrypt());
|
||||
}
|
||||
|
||||
byte[] response = saslServer.response(saslMessage.payload);
|
||||
callback.onSuccess(response);
|
||||
|
||||
// Setup encryption after the SASL response is sent, otherwise the client can't parse the
|
||||
// response. It's ok to change the channel pipeline here since we are processing an incoming
|
||||
// message, so the pipeline is busy and no new incoming messages will be fed to it before this
|
||||
// method returns. This assumes that the code ensures, through other means, that no outbound
|
||||
// messages are being written to the channel while negotiation is still going on.
|
||||
if (saslServer.isComplete()) {
|
||||
logger.debug("SASL authentication successful for channel {}", client);
|
||||
isComplete = true;
|
||||
if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
|
||||
logger.debug("Enabling encryption for channel {}", client);
|
||||
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
|
||||
saslServer = null;
|
||||
} else {
|
||||
saslServer.dispose();
|
||||
saslServer = null;
|
||||
}
|
||||
}
|
||||
callback.onSuccess(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -86,9 +114,9 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
|
||||
@Override
|
||||
public void connectionTerminated(TransportClient client) {
|
||||
SparkSaslServer saslServer = channelAuthenticationMap.remove(client);
|
||||
if (saslServer != null) {
|
||||
saslServer.dispose();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
/**
|
||||
* A bootstrap which is executed on a TransportServer's client channel once a client connects
|
||||
* to the server. This allows customizing the client channel to allow for things such as SASL
|
||||
* authentication.
|
||||
*/
|
||||
public class SaslServerBootstrap implements TransportServerBootstrap {
|
||||
|
||||
private final TransportConf conf;
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) {
|
||||
this.conf = conf;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL
|
||||
* negotiation.
|
||||
*/
|
||||
public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
|
||||
return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder);
|
||||
}
|
||||
|
||||
}
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import javax.security.auth.callback.Callback;
|
||||
import javax.security.auth.callback.CallbackHandler;
|
||||
import javax.security.auth.callback.NameCallback;
|
||||
|
@ -27,9 +29,9 @@ import javax.security.sasl.RealmChoiceCallback;
|
|||
import javax.security.sasl.Sasl;
|
||||
import javax.security.sasl.SaslClient;
|
||||
import javax.security.sasl.SaslException;
|
||||
import java.io.IOException;
|
||||
|
||||
import com.google.common.base.Throwables;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
@ -40,19 +42,25 @@ import static org.apache.spark.network.sasl.SparkSaslServer.*;
|
|||
* initial state to the "authenticated" state. This client initializes the protocol via a
|
||||
* firstToken, which is then followed by a set of challenges and responses.
|
||||
*/
|
||||
public class SparkSaslClient {
|
||||
public class SparkSaslClient implements SaslEncryptionBackend {
|
||||
private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class);
|
||||
|
||||
private final String secretKeyId;
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
private final String expectedQop;
|
||||
private SaslClient saslClient;
|
||||
|
||||
public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) {
|
||||
public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
|
||||
this.secretKeyId = secretKeyId;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
|
||||
|
||||
Map<String, String> saslProps = ImmutableMap.<String, String>builder()
|
||||
.put(Sasl.QOP, expectedQop)
|
||||
.build();
|
||||
try {
|
||||
this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM,
|
||||
SASL_PROPS, new ClientCallbackHandler());
|
||||
saslProps, new ClientCallbackHandler());
|
||||
} catch (SaslException e) {
|
||||
throw Throwables.propagate(e);
|
||||
}
|
||||
|
@ -76,6 +84,11 @@ public class SparkSaslClient {
|
|||
return saslClient != null && saslClient.isComplete();
|
||||
}
|
||||
|
||||
/** Returns the value of a negotiated property. */
|
||||
public Object getNegotiatedProperty(String name) {
|
||||
return saslClient.getNegotiatedProperty(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Respond to server's SASL token.
|
||||
* @param token contains server's SASL token
|
||||
|
@ -93,6 +106,7 @@ public class SparkSaslClient {
|
|||
* Disposes of any system resources or security-sensitive information the
|
||||
* SaslClient might be using.
|
||||
*/
|
||||
@Override
|
||||
public synchronized void dispose() {
|
||||
if (saslClient != null) {
|
||||
try {
|
||||
|
@ -134,4 +148,15 @@ public class SparkSaslClient {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
|
||||
return saslClient.wrap(data, offset, len);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
|
||||
return saslClient.unwrap(data, offset, len);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
|
|||
* initial state to the "authenticated" state. (It is not a server in the sense of accepting
|
||||
* connections on some socket.)
|
||||
*/
|
||||
public class SparkSaslServer {
|
||||
public class SparkSaslServer implements SaslEncryptionBackend {
|
||||
private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class);
|
||||
|
||||
/**
|
||||
|
@ -60,26 +60,37 @@ public class SparkSaslServer {
|
|||
static final String DIGEST = "DIGEST-MD5";
|
||||
|
||||
/**
|
||||
* The quality of protection is just "auth". This means that we are doing
|
||||
* authentication only, we are not supporting integrity or privacy protection of the
|
||||
* communication channel after authentication. This could be changed to be configurable
|
||||
* in the future.
|
||||
* Quality of protection value that includes encryption.
|
||||
*/
|
||||
static final Map<String, String> SASL_PROPS = ImmutableMap.<String, String>builder()
|
||||
.put(Sasl.QOP, "auth")
|
||||
.put(Sasl.SERVER_AUTH, "true")
|
||||
.build();
|
||||
static final String QOP_AUTH_CONF = "auth-conf";
|
||||
|
||||
/**
|
||||
* Quality of protection value that does not include encryption.
|
||||
*/
|
||||
static final String QOP_AUTH = "auth";
|
||||
|
||||
/** Identifier for a certain secret key within the secretKeyHolder. */
|
||||
private final String secretKeyId;
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
private SaslServer saslServer;
|
||||
|
||||
public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) {
|
||||
public SparkSaslServer(
|
||||
String secretKeyId,
|
||||
SecretKeyHolder secretKeyHolder,
|
||||
boolean alwaysEncrypt) {
|
||||
this.secretKeyId = secretKeyId;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
|
||||
// Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
|
||||
// is listed first since it's preferred over the non-encrypted one (if the client also
|
||||
// lists both in the request).
|
||||
String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH);
|
||||
Map<String, String> saslProps = ImmutableMap.<String, String>builder()
|
||||
.put(Sasl.SERVER_AUTH, "true")
|
||||
.put(Sasl.QOP, qop)
|
||||
.build();
|
||||
try {
|
||||
this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS,
|
||||
this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps,
|
||||
new DigestCallbackHandler());
|
||||
} catch (SaslException e) {
|
||||
throw Throwables.propagate(e);
|
||||
|
@ -93,6 +104,11 @@ public class SparkSaslServer {
|
|||
return saslServer != null && saslServer.isComplete();
|
||||
}
|
||||
|
||||
/** Returns the value of a negotiated property. */
|
||||
public Object getNegotiatedProperty(String name) {
|
||||
return saslServer.getNegotiatedProperty(name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Used to respond to server SASL tokens.
|
||||
* @param token Server's SASL token
|
||||
|
@ -110,6 +126,7 @@ public class SparkSaslServer {
|
|||
* Disposes of any system resources or security-sensitive information the
|
||||
* SaslServer might be using.
|
||||
*/
|
||||
@Override
|
||||
public synchronized void dispose() {
|
||||
if (saslServer != null) {
|
||||
try {
|
||||
|
@ -122,6 +139,16 @@ public class SparkSaslServer {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
|
||||
return saslServer.wrap(data, offset, len);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
|
||||
return saslServer.unwrap(data, offset, len);
|
||||
}
|
||||
|
||||
/**
|
||||
* Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism.
|
||||
*/
|
||||
|
|
|
@ -19,8 +19,11 @@ package org.apache.spark.network.server;
|
|||
|
||||
import java.io.Closeable;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
|
@ -44,15 +47,23 @@ public class TransportServer implements Closeable {
|
|||
|
||||
private final TransportContext context;
|
||||
private final TransportConf conf;
|
||||
private final RpcHandler appRpcHandler;
|
||||
private final List<TransportServerBootstrap> bootstraps;
|
||||
|
||||
private ServerBootstrap bootstrap;
|
||||
private ChannelFuture channelFuture;
|
||||
private int port = -1;
|
||||
|
||||
/** Creates a TransportServer that binds to the given port, or to any available if 0. */
|
||||
public TransportServer(TransportContext context, int portToBind) {
|
||||
public TransportServer(
|
||||
TransportContext context,
|
||||
int portToBind,
|
||||
RpcHandler appRpcHandler,
|
||||
List<TransportServerBootstrap> bootstraps) {
|
||||
this.context = context;
|
||||
this.conf = context.getConf();
|
||||
this.appRpcHandler = appRpcHandler;
|
||||
this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps));
|
||||
|
||||
init(portToBind);
|
||||
}
|
||||
|
@ -95,7 +106,11 @@ public class TransportServer implements Closeable {
|
|||
bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
protected void initChannel(SocketChannel ch) throws Exception {
|
||||
context.initializePipeline(ch);
|
||||
RpcHandler rpcHandler = appRpcHandler;
|
||||
for (TransportServerBootstrap bootstrap : bootstraps) {
|
||||
rpcHandler = bootstrap.doBootstrap(ch, rpcHandler);
|
||||
}
|
||||
context.initializePipeline(ch, rpcHandler);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.server;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
|
||||
/**
|
||||
* A bootstrap which is executed on a TransportServer's client channel once a client connects
|
||||
* to the server. This allows customizing the client channel to allow for things such as SASL
|
||||
* authentication.
|
||||
*/
|
||||
public interface TransportServerBootstrap {
|
||||
/**
|
||||
* Customizes the channel to include new features, if needed.
|
||||
*
|
||||
* @param channel The connected channel opened by the client.
|
||||
* @param rpcHandler The RPC handler for the server.
|
||||
* @return The RPC handler to use for the channel.
|
||||
*/
|
||||
RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
|
||||
}
|
|
@ -15,11 +15,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network;
|
||||
package org.apache.spark.network.util;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.WritableByteChannel;
|
||||
|
||||
/**
|
||||
* A writable channel that stores the written data in a byte array in memory.
|
||||
*/
|
||||
public class ByteArrayWritableChannel implements WritableByteChannel {
|
||||
|
||||
private final byte[] data;
|
||||
|
@ -27,19 +30,30 @@ public class ByteArrayWritableChannel implements WritableByteChannel {
|
|||
|
||||
public ByteArrayWritableChannel(int size) {
|
||||
this.data = new byte[size];
|
||||
this.offset = 0;
|
||||
}
|
||||
|
||||
public byte[] getData() {
|
||||
return data;
|
||||
}
|
||||
|
||||
public int length() {
|
||||
return offset;
|
||||
}
|
||||
|
||||
/** Resets the channel so that writing to it will overwrite the existing buffer. */
|
||||
public void reset() {
|
||||
offset = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads from the given buffer into the internal byte array.
|
||||
*/
|
||||
@Override
|
||||
public int write(ByteBuffer src) {
|
||||
int available = src.remaining();
|
||||
src.get(data, offset, available);
|
||||
offset += available;
|
||||
return available;
|
||||
int toTransfer = Math.min(src.remaining(), data.length - offset);
|
||||
src.get(data, offset, toTransfer);
|
||||
offset += toTransfer;
|
||||
return toTransfer;
|
||||
}
|
||||
|
||||
@Override
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.network.util;
|
||||
|
||||
import com.google.common.primitives.Ints;
|
||||
|
||||
/**
|
||||
* A central location that tracks all the settings we expose to users.
|
||||
*/
|
||||
|
@ -112,4 +114,20 @@ public class TransportConf {
|
|||
public int portMaxRetries() {
|
||||
return conf.getInt("spark.port.maxRetries", 16);
|
||||
}
|
||||
|
||||
/**
|
||||
* Maximum number of bytes to be encrypted at a time when SASL encryption is enabled.
|
||||
*/
|
||||
public int maxSaslEncryptedBlockSize() {
|
||||
return Ints.checkedCast(JavaUtils.byteStringAsBytes(
|
||||
conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k")));
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the server should enforce encryption on SASL-authenticated connections.
|
||||
*/
|
||||
public boolean saslServerAlwaysEncrypt() {
|
||||
return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcFailure;
|
|||
import org.apache.spark.network.protocol.RpcRequest;
|
||||
import org.apache.spark.network.protocol.RpcResponse;
|
||||
import org.apache.spark.network.protocol.StreamChunkId;
|
||||
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||
import org.apache.spark.network.util.NettyUtils;
|
||||
|
||||
public class ProtocolSuite {
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.junit.Test;
|
|||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.apache.spark.network.ByteArrayWritableChannel;
|
||||
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||
|
||||
public class MessageWithHeaderSuite {
|
||||
|
||||
|
|
|
@ -17,12 +17,47 @@
|
|||
|
||||
package org.apache.spark.network.sasl;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
import static com.google.common.base.Charsets.UTF_8;
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import javax.security.sasl.SaslException;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.io.ByteStreams;
|
||||
import com.google.common.io.Files;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelOutboundHandlerAdapter;
|
||||
import io.netty.channel.ChannelPromise;
|
||||
import org.junit.Test;
|
||||
import org.mockito.invocation.InvocationOnMock;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import org.apache.spark.network.TestUtils;
|
||||
import org.apache.spark.network.TransportContext;
|
||||
import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
|
||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||
import org.apache.spark.network.client.ChunkReceivedCallback;
|
||||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.TransportClient;
|
||||
import org.apache.spark.network.client.TransportClientBootstrap;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.StreamManager;
|
||||
import org.apache.spark.network.server.TransportServer;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||
import org.apache.spark.network.util.SystemPropertyConfigProvider;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
/**
|
||||
* Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes.
|
||||
|
@ -44,8 +79,8 @@ public class SparkSaslSuite {
|
|||
|
||||
@Test
|
||||
public void testMatching() {
|
||||
SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder);
|
||||
SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder);
|
||||
SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false);
|
||||
SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false);
|
||||
|
||||
assertFalse(client.isComplete());
|
||||
assertFalse(server.isComplete());
|
||||
|
@ -64,11 +99,10 @@ public class SparkSaslSuite {
|
|||
assertFalse(client.isComplete());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNonMatching() {
|
||||
SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder);
|
||||
SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder);
|
||||
SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false);
|
||||
SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false);
|
||||
|
||||
assertFalse(client.isComplete());
|
||||
assertFalse(server.isComplete());
|
||||
|
@ -86,4 +120,312 @@ public class SparkSaslSuite {
|
|||
assertFalse(server.isComplete());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSaslAuthentication() throws Exception {
|
||||
testBasicSasl(false);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSaslEncryption() throws Exception {
|
||||
testBasicSasl(true);
|
||||
}
|
||||
|
||||
private void testBasicSasl(boolean encrypt) throws Exception {
|
||||
RpcHandler rpcHandler = mock(RpcHandler.class);
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock invocation) {
|
||||
byte[] message = (byte[]) invocation.getArguments()[1];
|
||||
RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
|
||||
assertEquals("Ping", new String(message, UTF_8));
|
||||
cb.onSuccess("Pong".getBytes(UTF_8));
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.when(rpcHandler)
|
||||
.receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
|
||||
|
||||
SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
|
||||
try {
|
||||
byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
|
||||
assertEquals("Pong", new String(response, UTF_8));
|
||||
} finally {
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEncryptedMessage() throws Exception {
|
||||
SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
|
||||
byte[] data = new byte[1024];
|
||||
new Random().nextBytes(data);
|
||||
when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
|
||||
|
||||
ByteBuf msg = Unpooled.buffer();
|
||||
try {
|
||||
msg.writeBytes(data);
|
||||
|
||||
// Create a channel with a really small buffer compared to the data. This means that on each
|
||||
// call, the outbound data will not be fully written, so the write() method should return a
|
||||
// dummy count to keep the channel alive when possible.
|
||||
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32);
|
||||
|
||||
SaslEncryption.EncryptedMessage emsg =
|
||||
new SaslEncryption.EncryptedMessage(backend, msg, 1024);
|
||||
long count = emsg.transferTo(channel, emsg.transfered());
|
||||
assertTrue(count < data.length);
|
||||
assertTrue(count > 0);
|
||||
|
||||
// Here, the output buffer is full so nothing should be transferred.
|
||||
assertEquals(0, emsg.transferTo(channel, emsg.transfered()));
|
||||
|
||||
// Now there's room in the buffer, but not enough to transfer all the remaining data,
|
||||
// so the dummy count should be returned.
|
||||
channel.reset();
|
||||
assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
|
||||
|
||||
// Eventually, the whole message should be transferred.
|
||||
for (int i = 0; i < data.length / 32 - 2; i++) {
|
||||
channel.reset();
|
||||
assertEquals(1, emsg.transferTo(channel, emsg.transfered()));
|
||||
}
|
||||
|
||||
channel.reset();
|
||||
count = emsg.transferTo(channel, emsg.transfered());
|
||||
assertTrue("Unexpected count: " + count, count > 1 && count < data.length);
|
||||
assertEquals(data.length, emsg.transfered());
|
||||
} finally {
|
||||
msg.release();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEncryptedMessageChunking() throws Exception {
|
||||
File file = File.createTempFile("sasltest", ".txt");
|
||||
try {
|
||||
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
|
||||
|
||||
byte[] data = new byte[8 * 1024];
|
||||
new Random().nextBytes(data);
|
||||
Files.write(data, file);
|
||||
|
||||
SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class);
|
||||
// It doesn't really matter what we return here, as long as it's not null.
|
||||
when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data);
|
||||
|
||||
FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length());
|
||||
SaslEncryption.EncryptedMessage emsg =
|
||||
new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8);
|
||||
|
||||
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
|
||||
while (emsg.transfered() < emsg.count()) {
|
||||
channel.reset();
|
||||
emsg.transferTo(channel, emsg.transfered());
|
||||
}
|
||||
|
||||
verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt());
|
||||
} finally {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFileRegionEncryption() throws Exception {
|
||||
final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize";
|
||||
System.setProperty(blockSizeConf, "1k");
|
||||
|
||||
final AtomicReference<ManagedBuffer> response = new AtomicReference();
|
||||
final File file = File.createTempFile("sasltest", ".txt");
|
||||
SaslTestCtx ctx = null;
|
||||
try {
|
||||
final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
|
||||
StreamManager sm = mock(StreamManager.class);
|
||||
when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() {
|
||||
@Override
|
||||
public ManagedBuffer answer(InvocationOnMock invocation) {
|
||||
return new FileSegmentManagedBuffer(conf, file, 0, file.length());
|
||||
}
|
||||
});
|
||||
|
||||
RpcHandler rpcHandler = mock(RpcHandler.class);
|
||||
when(rpcHandler.getStreamManager()).thenReturn(sm);
|
||||
|
||||
byte[] data = new byte[8 * 1024];
|
||||
new Random().nextBytes(data);
|
||||
Files.write(data, file);
|
||||
|
||||
ctx = new SaslTestCtx(rpcHandler, true, false);
|
||||
|
||||
final Object lock = new Object();
|
||||
|
||||
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
|
||||
doAnswer(new Answer<Void>() {
|
||||
@Override
|
||||
public Void answer(InvocationOnMock invocation) {
|
||||
response.set((ManagedBuffer) invocation.getArguments()[1]);
|
||||
response.get().retain();
|
||||
synchronized (lock) {
|
||||
lock.notifyAll();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class));
|
||||
|
||||
synchronized (lock) {
|
||||
ctx.client.fetchChunk(0, 0, callback);
|
||||
lock.wait(10 * 1000);
|
||||
}
|
||||
|
||||
verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class));
|
||||
verify(callback, never()).onFailure(anyInt(), any(Throwable.class));
|
||||
|
||||
byte[] received = ByteStreams.toByteArray(response.get().createInputStream());
|
||||
assertTrue(Arrays.equals(data, received));
|
||||
} finally {
|
||||
file.delete();
|
||||
if (ctx != null) {
|
||||
ctx.close();
|
||||
}
|
||||
if (response.get() != null) {
|
||||
response.get().release();
|
||||
}
|
||||
System.clearProperty(blockSizeConf);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testServerAlwaysEncrypt() throws Exception {
|
||||
final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt";
|
||||
System.setProperty(alwaysEncryptConfName, "true");
|
||||
|
||||
SaslTestCtx ctx = null;
|
||||
try {
|
||||
ctx = new SaslTestCtx(mock(RpcHandler.class), false, false);
|
||||
fail("Should have failed to connect without encryption.");
|
||||
} catch (Exception e) {
|
||||
assertTrue(e.getCause() instanceof SaslException);
|
||||
} finally {
|
||||
if (ctx != null) {
|
||||
ctx.close();
|
||||
}
|
||||
System.clearProperty(alwaysEncryptConfName);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDataEncryptionIsActuallyEnabled() throws Exception {
|
||||
// This test sets up an encrypted connection but then, using a client bootstrap, removes
|
||||
// the encryption handler from the client side. This should cause the server to not be
|
||||
// able to understand RPCs sent to it and thus close the connection.
|
||||
SaslTestCtx ctx = null;
|
||||
try {
|
||||
ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
|
||||
ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10));
|
||||
fail("Should have failed to send RPC to server.");
|
||||
} catch (Exception e) {
|
||||
assertFalse(e.getCause() instanceof TimeoutException);
|
||||
} finally {
|
||||
if (ctx != null) {
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class SaslTestCtx {
|
||||
|
||||
final TransportClient client;
|
||||
final TransportServer server;
|
||||
|
||||
private final boolean encrypt;
|
||||
private final boolean disableClientEncryption;
|
||||
private final EncryptionCheckerBootstrap checker;
|
||||
|
||||
SaslTestCtx(
|
||||
RpcHandler rpcHandler,
|
||||
boolean encrypt,
|
||||
boolean disableClientEncryption)
|
||||
throws Exception {
|
||||
|
||||
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
|
||||
|
||||
SecretKeyHolder keyHolder = mock(SecretKeyHolder.class);
|
||||
when(keyHolder.getSaslUser(anyString())).thenReturn("user");
|
||||
when(keyHolder.getSecretKey(anyString())).thenReturn("secret");
|
||||
|
||||
TransportContext ctx = new TransportContext(conf, rpcHandler);
|
||||
|
||||
this.checker = new EncryptionCheckerBootstrap();
|
||||
this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder),
|
||||
checker));
|
||||
|
||||
try {
|
||||
List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList();
|
||||
clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt));
|
||||
if (disableClientEncryption) {
|
||||
clientBootstraps.add(new EncryptionDisablerBootstrap());
|
||||
}
|
||||
|
||||
this.client = ctx.createClientFactory(clientBootstraps)
|
||||
.createClient(TestUtils.getLocalHost(), server.getPort());
|
||||
} catch (Exception e) {
|
||||
close();
|
||||
throw e;
|
||||
}
|
||||
|
||||
this.encrypt = encrypt;
|
||||
this.disableClientEncryption = disableClientEncryption;
|
||||
}
|
||||
|
||||
void close() {
|
||||
if (!disableClientEncryption) {
|
||||
assertEquals(encrypt, checker.foundEncryptionHandler);
|
||||
}
|
||||
if (client != null) {
|
||||
client.close();
|
||||
}
|
||||
if (server != null) {
|
||||
server.close();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter
|
||||
implements TransportServerBootstrap {
|
||||
|
||||
boolean foundEncryptionHandler;
|
||||
|
||||
@Override
|
||||
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
|
||||
throws Exception {
|
||||
if (!foundEncryptionHandler) {
|
||||
foundEncryptionHandler =
|
||||
ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null;
|
||||
}
|
||||
ctx.write(msg, promise);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
|
||||
super.handlerRemoved(ctx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) {
|
||||
channel.pipeline().addFirst("encryptionChecker", this);
|
||||
return rpcHandler;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static class EncryptionDisablerBootstrap implements TransportClientBootstrap {
|
||||
|
||||
@Override
|
||||
public void doBootstrap(TransportClient client, Channel channel) {
|
||||
channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle;
|
|||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient {
|
|||
|
||||
private final TransportConf conf;
|
||||
private final boolean saslEnabled;
|
||||
private final boolean saslEncryptionEnabled;
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
private TransportClientFactory clientFactory;
|
||||
|
@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient {
|
|||
public ExternalShuffleClient(
|
||||
TransportConf conf,
|
||||
SecretKeyHolder secretKeyHolder,
|
||||
boolean saslEnabled) {
|
||||
boolean saslEnabled,
|
||||
boolean saslEncryptionEnabled) {
|
||||
Preconditions.checkArgument(
|
||||
!saslEncryptionEnabled || saslEnabled,
|
||||
"SASL encryption can only be enabled if SASL is also enabled.");
|
||||
this.conf = conf;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
this.saslEnabled = saslEnabled;
|
||||
this.saslEncryptionEnabled = saslEncryptionEnabled;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -70,7 +77,7 @@ public class ExternalShuffleClient extends ShuffleClient {
|
|||
TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
|
||||
List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
|
||||
if (saslEnabled) {
|
||||
bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder));
|
||||
bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
|
||||
}
|
||||
clientFactory = context.createClientFactory(bootstraps);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.network.sasl;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.After;
|
||||
|
@ -37,6 +38,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
|
|||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.StreamManager;
|
||||
import org.apache.spark.network.server.TransportServer;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
|
||||
import org.apache.spark.network.util.SystemPropertyConfigProvider;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
@ -72,10 +74,11 @@ public class SaslIntegrationSuite {
|
|||
@BeforeClass
|
||||
public static void beforeAll() throws IOException {
|
||||
SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key");
|
||||
SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder);
|
||||
conf = new TransportConf(new SystemPropertyConfigProvider());
|
||||
context = new TransportContext(conf, handler);
|
||||
server = context.createServer();
|
||||
context = new TransportContext(conf, new TestRpcHandler());
|
||||
|
||||
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
|
||||
server = context.createServer(Arrays.asList(bootstrap));
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ public class ExternalShuffleIntegrationSuite {
|
|||
|
||||
final Semaphore requestsRemaining = new Semaphore(0);
|
||||
|
||||
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
|
||||
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
|
||||
client.init(APP_ID);
|
||||
client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
|
||||
new BlockFetchingListener() {
|
||||
|
@ -274,7 +274,7 @@ public class ExternalShuffleIntegrationSuite {
|
|||
|
||||
private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
|
||||
throws IOException {
|
||||
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false);
|
||||
ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false);
|
||||
client.init(APP_ID);
|
||||
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(),
|
||||
executorId, executorInfo);
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.network.shuffle;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
@ -27,10 +28,11 @@ import static org.junit.Assert.*;
|
|||
|
||||
import org.apache.spark.network.TestUtils;
|
||||
import org.apache.spark.network.TransportContext;
|
||||
import org.apache.spark.network.sasl.SaslRpcHandler;
|
||||
import org.apache.spark.network.sasl.SaslServerBootstrap;
|
||||
import org.apache.spark.network.sasl.SecretKeyHolder;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.TransportServer;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
|
||||
import org.apache.spark.network.util.SystemPropertyConfigProvider;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
@ -42,10 +44,10 @@ public class ExternalShuffleSecuritySuite {
|
|||
|
||||
@Before
|
||||
public void beforeEach() {
|
||||
RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf),
|
||||
new TestSecretKeyHolder("my-app-id", "secret"));
|
||||
TransportContext context = new TransportContext(conf, handler);
|
||||
this.server = context.createServer();
|
||||
TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
|
||||
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf,
|
||||
new TestSecretKeyHolder("my-app-id", "secret"));
|
||||
this.server = context.createServer(Arrays.asList(bootstrap));
|
||||
}
|
||||
|
||||
@After
|
||||
|
@ -58,13 +60,13 @@ public class ExternalShuffleSecuritySuite {
|
|||
|
||||
@Test
|
||||
public void testValid() throws IOException {
|
||||
validate("my-app-id", "secret");
|
||||
validate("my-app-id", "secret", false);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBadAppId() {
|
||||
try {
|
||||
validate("wrong-app-id", "secret");
|
||||
validate("wrong-app-id", "secret", false);
|
||||
} catch (Exception e) {
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
|
||||
}
|
||||
|
@ -73,16 +75,21 @@ public class ExternalShuffleSecuritySuite {
|
|||
@Test
|
||||
public void testBadSecret() {
|
||||
try {
|
||||
validate("my-app-id", "bad-secret");
|
||||
validate("my-app-id", "bad-secret", false);
|
||||
} catch (Exception e) {
|
||||
assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEncryption() throws IOException {
|
||||
validate("my-app-id", "secret", true);
|
||||
}
|
||||
|
||||
/** Creates an ExternalShuffleClient and attempts to register with the server. */
|
||||
private void validate(String appId, String secretKey) throws IOException {
|
||||
private void validate(String appId, String secretKey, boolean encrypt) throws IOException {
|
||||
ExternalShuffleClient client =
|
||||
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true);
|
||||
new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt);
|
||||
client.init(appId);
|
||||
// Registration either succeeds or throws an exception.
|
||||
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0",
|
||||
|
|
|
@ -17,9 +17,10 @@
|
|||
|
||||
package org.apache.spark.network.yarn;
|
||||
|
||||
import java.lang.Override;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.yarn.api.records.ApplicationId;
|
||||
import org.apache.hadoop.yarn.api.records.ContainerId;
|
||||
|
@ -32,10 +33,11 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.spark.network.TransportContext;
|
||||
import org.apache.spark.network.sasl.SaslRpcHandler;
|
||||
import org.apache.spark.network.sasl.SaslServerBootstrap;
|
||||
import org.apache.spark.network.sasl.ShuffleSecretManager;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.TransportServer;
|
||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
import org.apache.spark.network.yarn.util.HadoopConfigProvider;
|
||||
|
@ -103,16 +105,17 @@ public class YarnShuffleService extends AuxiliaryService {
|
|||
// special RPC handler that filters out unauthenticated fetch requests
|
||||
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
|
||||
blockHandler = new ExternalShuffleBlockHandler(transportConf);
|
||||
RpcHandler rpcHandler = blockHandler;
|
||||
|
||||
List<TransportServerBootstrap> bootstraps = Lists.newArrayList();
|
||||
if (authEnabled) {
|
||||
secretManager = new ShuffleSecretManager();
|
||||
rpcHandler = new SaslRpcHandler(rpcHandler, secretManager);
|
||||
bootstraps.add(new SaslServerBootstrap(transportConf, secretManager));
|
||||
}
|
||||
|
||||
int port = conf.getInt(
|
||||
SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
|
||||
TransportContext transportContext = new TransportContext(transportConf, rpcHandler);
|
||||
shuffleServer = transportContext.createServer(port);
|
||||
TransportContext transportContext = new TransportContext(transportConf, blockHandler);
|
||||
shuffleServer = transportContext.createServer(port, bootstraps);
|
||||
String authEnabledString = authEnabled ? "enabled" : "not enabled";
|
||||
logger.info("Started YARN shuffle service for Spark on port {}. " +
|
||||
"Authentication is {}.", port, authEnabledString);
|
||||
|
|
Loading…
Reference in a new issue