[SPARK-30129][CORE] Set client's id in TransportClient after successful auth
The new auth code was missing this bit, so it was not possible to know which app a client belonged to when auth was on. I also refactored the SASL test that checks for this so it also checks the new protocol (test failed before the fix, passes now). Closes #26760 from vanzin/SPARK-30129. Authored-by: Marcelo Vanzin <vanzin@cloudera.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
29e09a83b7
commit
c5f312a6ac
|
@ -78,6 +78,7 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
doSparkAuth(client, channel);
|
doSparkAuth(client, channel);
|
||||||
|
client.setClientId(appId);
|
||||||
} catch (GeneralSecurityException | IOException e) {
|
} catch (GeneralSecurityException | IOException e) {
|
||||||
throw Throwables.propagate(e);
|
throw Throwables.propagate(e);
|
||||||
} catch (RuntimeException e) {
|
} catch (RuntimeException e) {
|
||||||
|
|
|
@ -125,6 +125,7 @@ class AuthRpcHandler extends RpcHandler {
|
||||||
response.encode(responseData);
|
response.encode(responseData);
|
||||||
callback.onSuccess(responseData.nioBuffer());
|
callback.onSuccess(responseData.nioBuffer());
|
||||||
engine.sessionCipher().addToChannel(channel);
|
engine.sessionCipher().addToChannel(channel);
|
||||||
|
client.setClientId(challenge.appId);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
// This is a fatal error: authentication has failed. Close the channel explicitly.
|
// This is a fatal error: authentication has failed. Close the channel explicitly.
|
||||||
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
|
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
|
||||||
|
|
|
@ -21,8 +21,6 @@ import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.concurrent.CountDownLatch;
|
|
||||||
import java.util.concurrent.atomic.AtomicReference;
|
|
||||||
|
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
import org.junit.AfterClass;
|
import org.junit.AfterClass;
|
||||||
|
@ -34,8 +32,6 @@ import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
import org.apache.spark.network.TestUtils;
|
import org.apache.spark.network.TestUtils;
|
||||||
import org.apache.spark.network.TransportContext;
|
import org.apache.spark.network.TransportContext;
|
||||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
|
||||||
import org.apache.spark.network.client.ChunkReceivedCallback;
|
|
||||||
import org.apache.spark.network.client.RpcResponseCallback;
|
import org.apache.spark.network.client.RpcResponseCallback;
|
||||||
import org.apache.spark.network.client.TransportClient;
|
import org.apache.spark.network.client.TransportClient;
|
||||||
import org.apache.spark.network.client.TransportClientFactory;
|
import org.apache.spark.network.client.TransportClientFactory;
|
||||||
|
@ -44,15 +40,6 @@ import org.apache.spark.network.server.RpcHandler;
|
||||||
import org.apache.spark.network.server.StreamManager;
|
import org.apache.spark.network.server.StreamManager;
|
||||||
import org.apache.spark.network.server.TransportServer;
|
import org.apache.spark.network.server.TransportServer;
|
||||||
import org.apache.spark.network.server.TransportServerBootstrap;
|
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||||
import org.apache.spark.network.shuffle.BlockFetchingListener;
|
|
||||||
import org.apache.spark.network.shuffle.ExternalBlockHandler;
|
|
||||||
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
|
|
||||||
import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
|
|
||||||
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
|
||||||
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
|
|
||||||
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
|
||||||
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
|
|
||||||
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
|
||||||
import org.apache.spark.network.util.JavaUtils;
|
import org.apache.spark.network.util.JavaUtils;
|
||||||
import org.apache.spark.network.util.MapConfigProvider;
|
import org.apache.spark.network.util.MapConfigProvider;
|
||||||
import org.apache.spark.network.util.TransportConf;
|
import org.apache.spark.network.util.TransportConf;
|
||||||
|
@ -165,93 +152,6 @@ public class SaslIntegrationSuite {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* This test is not actually testing SASL behavior, but testing that the shuffle service
|
|
||||||
* performs correct authorization checks based on the SASL authentication data.
|
|
||||||
*/
|
|
||||||
@Test
|
|
||||||
public void testAppIsolation() throws Exception {
|
|
||||||
// Start a new server with the correct RPC handler to serve block data.
|
|
||||||
ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
|
|
||||||
ExternalBlockHandler blockHandler = new ExternalBlockHandler(
|
|
||||||
new OneForOneStreamManager(), blockResolver);
|
|
||||||
TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
|
|
||||||
|
|
||||||
try (
|
|
||||||
TransportContext blockServerContext = new TransportContext(conf, blockHandler);
|
|
||||||
TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
|
|
||||||
// Create a client, and make a request to fetch blocks from a different app.
|
|
||||||
TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
|
|
||||||
Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
|
|
||||||
TransportClient client1 = clientFactory1.createClient(
|
|
||||||
TestUtils.getLocalHost(), blockServer.getPort())) {
|
|
||||||
|
|
||||||
AtomicReference<Throwable> exception = new AtomicReference<>();
|
|
||||||
|
|
||||||
CountDownLatch blockFetchLatch = new CountDownLatch(1);
|
|
||||||
BlockFetchingListener listener = new BlockFetchingListener() {
|
|
||||||
@Override
|
|
||||||
public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
|
|
||||||
blockFetchLatch.countDown();
|
|
||||||
}
|
|
||||||
@Override
|
|
||||||
public void onBlockFetchFailure(String blockId, Throwable t) {
|
|
||||||
exception.set(t);
|
|
||||||
blockFetchLatch.countDown();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
|
|
||||||
OneForOneBlockFetcher fetcher =
|
|
||||||
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
|
|
||||||
fetcher.start();
|
|
||||||
blockFetchLatch.await();
|
|
||||||
checkSecurityException(exception.get());
|
|
||||||
|
|
||||||
// Register an executor so that the next steps work.
|
|
||||||
ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
|
|
||||||
new String[] { System.getProperty("java.io.tmpdir") }, 1,
|
|
||||||
"org.apache.spark.shuffle.sort.SortShuffleManager");
|
|
||||||
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
|
|
||||||
client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
|
|
||||||
|
|
||||||
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
|
|
||||||
// fetch any blocks, to keep the stream open.
|
|
||||||
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
|
|
||||||
ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
|
|
||||||
StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
|
|
||||||
long streamId = stream.streamId;
|
|
||||||
|
|
||||||
try (
|
|
||||||
// Create a second client, authenticated with a different app ID, and try to read from
|
|
||||||
// the stream created for the previous app.
|
|
||||||
TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
|
|
||||||
Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
|
|
||||||
TransportClient client2 = clientFactory2.createClient(
|
|
||||||
TestUtils.getLocalHost(), blockServer.getPort())
|
|
||||||
) {
|
|
||||||
CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
|
|
||||||
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
|
|
||||||
@Override
|
|
||||||
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
|
|
||||||
chunkReceivedLatch.countDown();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onFailure(int chunkIndex, Throwable t) {
|
|
||||||
exception.set(t);
|
|
||||||
chunkReceivedLatch.countDown();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
exception.set(null);
|
|
||||||
client2.fetchChunk(streamId, 0, callback);
|
|
||||||
chunkReceivedLatch.await();
|
|
||||||
checkSecurityException(exception.get());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/** RPC handler which simply responds with the message it received. */
|
/** RPC handler which simply responds with the message it received. */
|
||||||
public static class TestRpcHandler extends RpcHandler {
|
public static class TestRpcHandler extends RpcHandler {
|
||||||
@Override
|
@Override
|
||||||
|
@ -264,10 +164,4 @@ public class SaslIntegrationSuite {
|
||||||
return new OneForOneStreamManager();
|
return new OneForOneStreamManager();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void checkSecurityException(Throwable t) {
|
|
||||||
assertNotNull("No exception was caught.", t);
|
|
||||||
assertTrue("Expected SecurityException.",
|
|
||||||
t.getMessage().contains(SecurityException.class.getName()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,184 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.network.shuffle;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Supplier;
|
||||||
|
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
import org.apache.spark.network.TestUtils;
|
||||||
|
import org.apache.spark.network.TransportContext;
|
||||||
|
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||||
|
import org.apache.spark.network.client.ChunkReceivedCallback;
|
||||||
|
import org.apache.spark.network.client.TransportClient;
|
||||||
|
import org.apache.spark.network.client.TransportClientBootstrap;
|
||||||
|
import org.apache.spark.network.client.TransportClientFactory;
|
||||||
|
import org.apache.spark.network.crypto.AuthClientBootstrap;
|
||||||
|
import org.apache.spark.network.crypto.AuthServerBootstrap;
|
||||||
|
import org.apache.spark.network.sasl.SaslClientBootstrap;
|
||||||
|
import org.apache.spark.network.sasl.SaslServerBootstrap;
|
||||||
|
import org.apache.spark.network.sasl.SecretKeyHolder;
|
||||||
|
import org.apache.spark.network.server.OneForOneStreamManager;
|
||||||
|
import org.apache.spark.network.server.TransportServer;
|
||||||
|
import org.apache.spark.network.server.TransportServerBootstrap;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
|
||||||
|
import org.apache.spark.network.shuffle.protocol.StreamHandle;
|
||||||
|
import org.apache.spark.network.util.MapConfigProvider;
|
||||||
|
import org.apache.spark.network.util.TransportConf;
|
||||||
|
|
||||||
|
public class AppIsolationSuite {
|
||||||
|
|
||||||
|
// Use a long timeout to account for slow / overloaded build machines. In the normal case,
|
||||||
|
// tests should finish way before the timeout expires.
|
||||||
|
private static final long TIMEOUT_MS = 10_000;
|
||||||
|
|
||||||
|
private static SecretKeyHolder secretKeyHolder;
|
||||||
|
private static TransportConf conf;
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void beforeAll() {
|
||||||
|
Map<String, String> confMap = new HashMap<>();
|
||||||
|
confMap.put("spark.network.crypto.enabled", "true");
|
||||||
|
confMap.put("spark.network.crypto.saslFallback", "false");
|
||||||
|
conf = new TransportConf("shuffle", new MapConfigProvider(confMap));
|
||||||
|
|
||||||
|
secretKeyHolder = mock(SecretKeyHolder.class);
|
||||||
|
when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
|
||||||
|
when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
|
||||||
|
when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
|
||||||
|
when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSaslAppIsolation() throws Exception {
|
||||||
|
testAppIsolation(
|
||||||
|
() -> new SaslServerBootstrap(conf, secretKeyHolder),
|
||||||
|
appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAuthEngineAppIsolation() throws Exception {
|
||||||
|
testAppIsolation(
|
||||||
|
() -> new AuthServerBootstrap(conf, secretKeyHolder),
|
||||||
|
appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testAppIsolation(
|
||||||
|
Supplier<TransportServerBootstrap> serverBootstrap,
|
||||||
|
Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
|
||||||
|
// Start a new server with the correct RPC handler to serve block data.
|
||||||
|
ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
|
||||||
|
ExternalBlockHandler blockHandler = new ExternalBlockHandler(
|
||||||
|
new OneForOneStreamManager(), blockResolver);
|
||||||
|
TransportServerBootstrap bootstrap = serverBootstrap.get();
|
||||||
|
|
||||||
|
try (
|
||||||
|
TransportContext blockServerContext = new TransportContext(conf, blockHandler);
|
||||||
|
TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
|
||||||
|
// Create a client, and make a request to fetch blocks from a different app.
|
||||||
|
TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
|
||||||
|
Arrays.asList(clientBootstrapFactory.apply("app-1")));
|
||||||
|
TransportClient client1 = clientFactory1.createClient(
|
||||||
|
TestUtils.getLocalHost(), blockServer.getPort())) {
|
||||||
|
|
||||||
|
AtomicReference<Throwable> exception = new AtomicReference<>();
|
||||||
|
|
||||||
|
CountDownLatch blockFetchLatch = new CountDownLatch(1);
|
||||||
|
BlockFetchingListener listener = new BlockFetchingListener() {
|
||||||
|
@Override
|
||||||
|
public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
|
||||||
|
blockFetchLatch.countDown();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public void onBlockFetchFailure(String blockId, Throwable t) {
|
||||||
|
exception.set(t);
|
||||||
|
blockFetchLatch.countDown();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
|
||||||
|
OneForOneBlockFetcher fetcher =
|
||||||
|
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
|
||||||
|
fetcher.start();
|
||||||
|
blockFetchLatch.await();
|
||||||
|
checkSecurityException(exception.get());
|
||||||
|
|
||||||
|
// Register an executor so that the next steps work.
|
||||||
|
ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
|
||||||
|
new String[] { System.getProperty("java.io.tmpdir") }, 1,
|
||||||
|
"org.apache.spark.shuffle.sort.SortShuffleManager");
|
||||||
|
RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
|
||||||
|
client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
|
||||||
|
|
||||||
|
// Make a successful request to fetch blocks, which creates a new stream. But do not actually
|
||||||
|
// fetch any blocks, to keep the stream open.
|
||||||
|
OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
|
||||||
|
ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
|
||||||
|
StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
|
||||||
|
long streamId = stream.streamId;
|
||||||
|
|
||||||
|
try (
|
||||||
|
// Create a second client, authenticated with a different app ID, and try to read from
|
||||||
|
// the stream created for the previous app.
|
||||||
|
TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
|
||||||
|
Arrays.asList(clientBootstrapFactory.apply("app-2")));
|
||||||
|
TransportClient client2 = clientFactory2.createClient(
|
||||||
|
TestUtils.getLocalHost(), blockServer.getPort())
|
||||||
|
) {
|
||||||
|
CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
|
||||||
|
ChunkReceivedCallback callback = new ChunkReceivedCallback() {
|
||||||
|
@Override
|
||||||
|
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
|
||||||
|
chunkReceivedLatch.countDown();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(int chunkIndex, Throwable t) {
|
||||||
|
exception.set(t);
|
||||||
|
chunkReceivedLatch.countDown();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
exception.set(null);
|
||||||
|
client2.fetchChunk(streamId, 0, callback);
|
||||||
|
chunkReceivedLatch.await();
|
||||||
|
checkSecurityException(exception.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void checkSecurityException(Throwable t) {
|
||||||
|
assertNotNull("No exception was caught.", t);
|
||||||
|
assertTrue("Expected SecurityException.",
|
||||||
|
t.getMessage().contains(SecurityException.class.getName()));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue