Apply appropriate RPC handler to receive, receiveStream when auth enabled

This commit is contained in:
Sean Owen 2020-04-17 13:25:12 -05:00
parent a7fb330ed3
commit 61b7d446b3
5 changed files with 142 additions and 113 deletions

View file

@ -29,12 +29,11 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.TransportConf;
/**
@ -46,7 +45,7 @@ import org.apache.spark.network.util.TransportConf;
* The delegate will only receive messages if the given connection has been successfully
* authenticated. A connection may be authenticated at most once.
*/
class AuthRpcHandler extends RpcHandler {
class AuthRpcHandler extends AbstractAuthRpcHandler {
private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);
/** Transport configuration. */
@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;
/**
* RpcHandler we will delegate to for authenticated connections. When falling back to SASL
* this will be replaced with the SASL RPC handler.
*/
@VisibleForTesting
RpcHandler delegate;
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
/** Whether auth is done and future calls should be delegated. */
/** RPC handler for auth handshake when falling back to SASL auth. */
@VisibleForTesting
boolean doDelegate;
SaslRpcHandler saslHandler;
AuthRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
super(delegate);
this.conf = conf;
this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
}
@Override
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (doDelegate) {
delegate.receive(client, message, callback);
return;
protected boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (saslHandler != null) {
return saslHandler.doAuthChallenge(client, message, callback);
}
int position = message.position();
@ -98,18 +92,17 @@ class AuthRpcHandler extends RpcHandler {
if (conf.saslFallback()) {
LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
channel.remoteAddress());
delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
message.position(position);
message.limit(limit);
delegate.receive(client, message, callback);
doDelegate = true;
return saslHandler.doAuthChallenge(client, message, callback);
} else {
LOG.debug("Unexpected challenge message from client {}, closing channel.",
channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
channel.close();
}
return;
return false;
}
// Here we have the client challenge, so perform the new auth protocol and set up the channel.
@ -131,7 +124,7 @@ class AuthRpcHandler extends RpcHandler {
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Authentication failed."));
channel.close();
return;
return false;
} finally {
if (engine != null) {
try {
@ -143,40 +136,6 @@ class AuthRpcHandler extends RpcHandler {
}
LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
doDelegate = true;
return true;
}
@Override
public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}
@Override
public StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
return delegate.receiveStream(client, message, callback);
}
@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}
@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
}
@Override
public void channelInactive(TransportClient client) {
delegate.channelInactive(client);
}
@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}
}

View file

@ -28,10 +28,9 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;
@ -43,7 +42,7 @@ import org.apache.spark.network.util.TransportConf;
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
public class SaslRpcHandler extends RpcHandler {
public class SaslRpcHandler extends AbstractAuthRpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);
/** Transport configuration. */
@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;
/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;
/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;
private SparkSaslServer saslServer;
private boolean isComplete;
private boolean isAuthenticated;
public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
super(delegate);
this.conf = conf;
this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
this.isComplete = false;
this.isAuthenticated = false;
}
@Override
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
}
public boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (saslServer == null || !saslServer.isComplete()) {
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
SaslMessage saslMessage;
@ -118,43 +108,21 @@ public class SaslRpcHandler extends RpcHandler {
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("SASL authentication successful for channel {}", client);
complete(true);
return;
return true;
}
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
complete(false);
return;
return true;
}
}
@Override
public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}
@Override
public StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
return delegate.receiveStream(client, message, callback);
}
@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}
@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
return false;
}
@Override
public void channelInactive(TransportClient client) {
try {
delegate.channelInactive(client);
super.channelInactive(client);
} finally {
if (saslServer != null) {
saslServer.dispose();
@ -162,11 +130,6 @@ public class SaslRpcHandler extends RpcHandler {
}
}
@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}
private void complete(boolean dispose) {
if (dispose) {
try {
@ -177,7 +140,6 @@ public class SaslRpcHandler extends RpcHandler {
}
saslServer = null;
isComplete = true;
}
}

View file

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.server;
import java.nio.ByteBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
/**
* RPC Handler which performs authentication, and when it's successful, delegates further
* calls to another RPC handler. The authentication handshake itself should be implemented
* by subclasses.
*/
public abstract class AbstractAuthRpcHandler extends RpcHandler {
/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;
private boolean isAuthenticated;
protected AbstractAuthRpcHandler(RpcHandler delegate) {
this.delegate = delegate;
}
/**
* Responds to an authentication challenge.
*
* @return Whether the client is authenticated.
*/
protected abstract boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback);
@Override
public final void receive(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (isAuthenticated) {
delegate.receive(client, message, callback);
} else {
isAuthenticated = doAuthChallenge(client, message, callback);
}
}
@Override
public final void receive(TransportClient client, ByteBuffer message) {
if (isAuthenticated) {
delegate.receive(client, message);
} else {
throw new SecurityException("Unauthenticated call to receive().");
}
}
@Override
public final StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (isAuthenticated) {
return delegate.receiveStream(client, message, callback);
} else {
throw new SecurityException("Unauthenticated call to receiveStream().");
}
}
@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}
@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
}
@Override
public void channelInactive(TransportClient client) {
delegate.channelInactive(client);
}
@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}
public boolean isAuthenticated() {
return isAuthenticated;
}
}

View file

@ -34,7 +34,6 @@ import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.sasl.SaslServerBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.RpcHandler;
@ -65,8 +64,7 @@ public class AuthIntegrationSuite {
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
assertEquals("Pong", JavaUtils.bytesToString(reply));
assertTrue(ctx.authRpcHandler.doDelegate);
assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler);
assertNull(ctx.authRpcHandler.saslHandler);
}
@Test
@ -78,7 +76,7 @@ public class AuthIntegrationSuite {
ctx.createClient("client");
fail("Should have failed to create client.");
} catch (Exception e) {
assertFalse(ctx.authRpcHandler.doDelegate);
assertFalse(ctx.authRpcHandler.isAuthenticated());
assertFalse(ctx.serverChannel.isActive());
}
}
@ -91,6 +89,8 @@ public class AuthIntegrationSuite {
ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
assertEquals("Pong", JavaUtils.bytesToString(reply));
assertNotNull(ctx.authRpcHandler.saslHandler);
assertTrue(ctx.authRpcHandler.isAuthenticated());
}
@Test
@ -120,7 +120,7 @@ public class AuthIntegrationSuite {
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
fail("Should have failed unencrypted RPC.");
} catch (Exception e) {
assertTrue(ctx.authRpcHandler.doDelegate);
assertTrue(ctx.authRpcHandler.isAuthenticated());
}
}
@ -151,7 +151,7 @@ public class AuthIntegrationSuite {
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
fail("Should have failed unencrypted RPC.");
} catch (Exception e) {
assertTrue(ctx.authRpcHandler.doDelegate);
assertTrue(ctx.authRpcHandler.isAuthenticated());
assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
// Verify we receive the complete error message
int messageStart = e.getMessage().indexOf("DDDDD");

View file

@ -357,7 +357,8 @@ public class SparkSaslSuite {
public void testDelegates() throws Exception {
Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
for (Method m : rpcHandlerMethods) {
SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes());
assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class);
}
}