Apply appropriate RPC handler to receive, receiveStream when auth enabled
This commit is contained in:
parent
a7fb330ed3
commit
61b7d446b3
|
@ -29,12 +29,11 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.StreamCallbackWithID;
|
||||
import org.apache.spark.network.client.TransportClient;
|
||||
import org.apache.spark.network.sasl.SecretKeyHolder;
|
||||
import org.apache.spark.network.sasl.SaslRpcHandler;
|
||||
import org.apache.spark.network.server.AbstractAuthRpcHandler;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.StreamManager;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
/**
|
||||
|
@ -46,7 +45,7 @@ import org.apache.spark.network.util.TransportConf;
|
|||
* The delegate will only receive messages if the given connection has been successfully
|
||||
* authenticated. A connection may be authenticated at most once.
|
||||
*/
|
||||
class AuthRpcHandler extends RpcHandler {
|
||||
class AuthRpcHandler extends AbstractAuthRpcHandler {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
|
||||
|
||||
/** Transport configuration. */
|
||||
|
@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
|
|||
/** The client channel. */
|
||||
private final Channel channel;
|
||||
|
||||
/**
|
||||
* RpcHandler we will delegate to for authenticated connections. When falling back to SASL
|
||||
* this will be replaced with the SASL RPC handler.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
RpcHandler delegate;
|
||||
|
||||
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
/** Whether auth is done and future calls should be delegated. */
|
||||
/** RPC handler for auth handshake when falling back to SASL auth. */
|
||||
@VisibleForTesting
|
||||
boolean doDelegate;
|
||||
SaslRpcHandler saslHandler;
|
||||
|
||||
AuthRpcHandler(
|
||||
TransportConf conf,
|
||||
Channel channel,
|
||||
RpcHandler delegate,
|
||||
SecretKeyHolder secretKeyHolder) {
|
||||
super(delegate);
|
||||
this.conf = conf;
|
||||
this.channel = channel;
|
||||
this.delegate = delegate;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
|
||||
if (doDelegate) {
|
||||
delegate.receive(client, message, callback);
|
||||
return;
|
||||
protected boolean doAuthChallenge(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
if (saslHandler != null) {
|
||||
return saslHandler.doAuthChallenge(client, message, callback);
|
||||
}
|
||||
|
||||
int position = message.position();
|
||||
|
@ -98,18 +92,17 @@ class AuthRpcHandler extends RpcHandler {
|
|||
if (conf.saslFallback()) {
|
||||
LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
|
||||
channel.remoteAddress());
|
||||
delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
|
||||
saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
|
||||
message.position(position);
|
||||
message.limit(limit);
|
||||
delegate.receive(client, message, callback);
|
||||
doDelegate = true;
|
||||
return saslHandler.doAuthChallenge(client, message, callback);
|
||||
} else {
|
||||
LOG.debug("Unexpected challenge message from client {}, closing channel.",
|
||||
channel.remoteAddress());
|
||||
callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
|
||||
channel.close();
|
||||
}
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Here we have the client challenge, so perform the new auth protocol and set up the channel.
|
||||
|
@ -131,7 +124,7 @@ class AuthRpcHandler extends RpcHandler {
|
|||
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
|
||||
callback.onFailure(new IllegalArgumentException("Authentication failed."));
|
||||
channel.close();
|
||||
return;
|
||||
return false;
|
||||
} finally {
|
||||
if (engine != null) {
|
||||
try {
|
||||
|
@ -143,40 +136,6 @@ class AuthRpcHandler extends RpcHandler {
|
|||
}
|
||||
|
||||
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
|
||||
doDelegate = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, ByteBuffer message) {
|
||||
delegate.receive(client, message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamCallbackWithID receiveStream(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
return delegate.receiveStream(client, message, callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamManager getStreamManager() {
|
||||
return delegate.getStreamManager();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(TransportClient client) {
|
||||
delegate.channelActive(client);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(TransportClient client) {
|
||||
delegate.channelInactive(client);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(Throwable cause, TransportClient client) {
|
||||
delegate.exceptionCaught(cause, client);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -28,10 +28,9 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.StreamCallbackWithID;
|
||||
import org.apache.spark.network.client.TransportClient;
|
||||
import org.apache.spark.network.server.AbstractAuthRpcHandler;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
import org.apache.spark.network.server.StreamManager;
|
||||
import org.apache.spark.network.util.JavaUtils;
|
||||
import org.apache.spark.network.util.TransportConf;
|
||||
|
||||
|
@ -43,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
|
|||
* Note that the authentication process consists of multiple challenge-response pairs, each of
|
||||
* which are individual RPCs.
|
||||
*/
|
||||
public class SaslRpcHandler extends RpcHandler {
|
||||
public class SaslRpcHandler extends AbstractAuthRpcHandler {
|
||||
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
|
||||
|
||||
/** Transport configuration. */
|
||||
|
@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
/** The client channel. */
|
||||
private final Channel channel;
|
||||
|
||||
/** RpcHandler we will delegate to for authenticated connections. */
|
||||
private final RpcHandler delegate;
|
||||
|
||||
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
|
||||
private final SecretKeyHolder secretKeyHolder;
|
||||
|
||||
private SparkSaslServer saslServer;
|
||||
private boolean isComplete;
|
||||
private boolean isAuthenticated;
|
||||
|
||||
public SaslRpcHandler(
|
||||
TransportConf conf,
|
||||
Channel channel,
|
||||
RpcHandler delegate,
|
||||
SecretKeyHolder secretKeyHolder) {
|
||||
super(delegate);
|
||||
this.conf = conf;
|
||||
this.channel = channel;
|
||||
this.delegate = delegate;
|
||||
this.secretKeyHolder = secretKeyHolder;
|
||||
this.saslServer = null;
|
||||
this.isComplete = false;
|
||||
this.isAuthenticated = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
|
||||
if (isComplete) {
|
||||
// Authentication complete, delegate to base handler.
|
||||
delegate.receive(client, message, callback);
|
||||
return;
|
||||
}
|
||||
public boolean doAuthChallenge(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
if (saslServer == null || !saslServer.isComplete()) {
|
||||
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
|
||||
SaslMessage saslMessage;
|
||||
|
@ -118,43 +108,21 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
|
||||
logger.debug("SASL authentication successful for channel {}", client);
|
||||
complete(true);
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
|
||||
logger.debug("Enabling encryption for channel {}", client);
|
||||
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
|
||||
complete(false);
|
||||
return;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, ByteBuffer message) {
|
||||
delegate.receive(client, message);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamCallbackWithID receiveStream(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
return delegate.receiveStream(client, message, callback);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamManager getStreamManager() {
|
||||
return delegate.getStreamManager();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(TransportClient client) {
|
||||
delegate.channelActive(client);
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(TransportClient client) {
|
||||
try {
|
||||
delegate.channelInactive(client);
|
||||
super.channelInactive(client);
|
||||
} finally {
|
||||
if (saslServer != null) {
|
||||
saslServer.dispose();
|
||||
|
@ -162,11 +130,6 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(Throwable cause, TransportClient client) {
|
||||
delegate.exceptionCaught(cause, client);
|
||||
}
|
||||
|
||||
private void complete(boolean dispose) {
|
||||
if (dispose) {
|
||||
try {
|
||||
|
@ -177,7 +140,6 @@ public class SaslRpcHandler extends RpcHandler {
|
|||
}
|
||||
|
||||
saslServer = null;
|
||||
isComplete = true;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.server;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.StreamCallbackWithID;
|
||||
import org.apache.spark.network.client.TransportClient;
|
||||
|
||||
/**
|
||||
* RPC Handler which performs authentication, and when it's successful, delegates further
|
||||
* calls to another RPC handler. The authentication handshake itself should be implemented
|
||||
* by subclasses.
|
||||
*/
|
||||
public abstract class AbstractAuthRpcHandler extends RpcHandler {
|
||||
/** RpcHandler we will delegate to for authenticated connections. */
|
||||
private final RpcHandler delegate;
|
||||
|
||||
private boolean isAuthenticated;
|
||||
|
||||
protected AbstractAuthRpcHandler(RpcHandler delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Responds to an authentication challenge.
|
||||
*
|
||||
* @return Whether the client is authenticated.
|
||||
*/
|
||||
protected abstract boolean doAuthChallenge(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback);
|
||||
|
||||
@Override
|
||||
public final void receive(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
if (isAuthenticated) {
|
||||
delegate.receive(client, message, callback);
|
||||
} else {
|
||||
isAuthenticated = doAuthChallenge(client, message, callback);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void receive(TransportClient client, ByteBuffer message) {
|
||||
if (isAuthenticated) {
|
||||
delegate.receive(client, message);
|
||||
} else {
|
||||
throw new SecurityException("Unauthenticated call to receive().");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public final StreamCallbackWithID receiveStream(
|
||||
TransportClient client,
|
||||
ByteBuffer message,
|
||||
RpcResponseCallback callback) {
|
||||
if (isAuthenticated) {
|
||||
return delegate.receiveStream(client, message, callback);
|
||||
} else {
|
||||
throw new SecurityException("Unauthenticated call to receiveStream().");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamManager getStreamManager() {
|
||||
return delegate.getStreamManager();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(TransportClient client) {
|
||||
delegate.channelActive(client);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(TransportClient client) {
|
||||
delegate.channelInactive(client);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(Throwable cause, TransportClient client) {
|
||||
delegate.exceptionCaught(cause, client);
|
||||
}
|
||||
|
||||
public boolean isAuthenticated() {
|
||||
return isAuthenticated;
|
||||
}
|
||||
}
|
|
@ -34,7 +34,6 @@ import org.apache.spark.network.TransportContext;
|
|||
import org.apache.spark.network.client.RpcResponseCallback;
|
||||
import org.apache.spark.network.client.TransportClient;
|
||||
import org.apache.spark.network.client.TransportClientBootstrap;
|
||||
import org.apache.spark.network.sasl.SaslRpcHandler;
|
||||
import org.apache.spark.network.sasl.SaslServerBootstrap;
|
||||
import org.apache.spark.network.sasl.SecretKeyHolder;
|
||||
import org.apache.spark.network.server.RpcHandler;
|
||||
|
@ -65,8 +64,7 @@ public class AuthIntegrationSuite {
|
|||
|
||||
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
|
||||
assertEquals("Pong", JavaUtils.bytesToString(reply));
|
||||
assertTrue(ctx.authRpcHandler.doDelegate);
|
||||
assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
|
||||
assertNull(ctx.authRpcHandler.saslHandler);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -78,7 +76,7 @@ public class AuthIntegrationSuite {
|
|||
ctx.createClient("client");
|
||||
fail("Should have failed to create client.");
|
||||
} catch (Exception e) {
|
||||
assertFalse(ctx.authRpcHandler.doDelegate);
|
||||
assertFalse(ctx.authRpcHandler.isAuthenticated());
|
||||
assertFalse(ctx.serverChannel.isActive());
|
||||
}
|
||||
}
|
||||
|
@ -91,6 +89,8 @@ public class AuthIntegrationSuite {
|
|||
|
||||
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
|
||||
assertEquals("Pong", JavaUtils.bytesToString(reply));
|
||||
assertNotNull(ctx.authRpcHandler.saslHandler);
|
||||
assertTrue(ctx.authRpcHandler.isAuthenticated());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -120,7 +120,7 @@ public class AuthIntegrationSuite {
|
|||
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
|
||||
fail("Should have failed unencrypted RPC.");
|
||||
} catch (Exception e) {
|
||||
assertTrue(ctx.authRpcHandler.doDelegate);
|
||||
assertTrue(ctx.authRpcHandler.isAuthenticated());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ public class AuthIntegrationSuite {
|
|||
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
|
||||
fail("Should have failed unencrypted RPC.");
|
||||
} catch (Exception e) {
|
||||
assertTrue(ctx.authRpcHandler.doDelegate);
|
||||
assertTrue(ctx.authRpcHandler.isAuthenticated());
|
||||
assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
|
||||
// Verify we receive the complete error message
|
||||
int messageStart = e.getMessage().indexOf("DDDDD");
|
||||
|
|
|
@ -357,7 +357,8 @@ public class SparkSaslSuite {
|
|||
public void testDelegates() throws Exception {
|
||||
Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
|
||||
for (Method m : rpcHandlerMethods) {
|
||||
SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
|
||||
Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
|
||||
assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue