[SPARK-30129][CORE] Set client's id in TransportClient after successful auth

The new auth code was missing this bit, so it was not possible to know which
app a client belonged to when auth was on.
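
For context, here is a minimal sketch (not part of this change) of the kind of server-side check that a populated client id enables. The handler class is hypothetical; only TransportClient.getClientId() and the RpcHandler API come from Spark's network-common module:

    import java.nio.ByteBuffer;

    import org.apache.spark.network.client.RpcResponseCallback;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.server.OneForOneStreamManager;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.StreamManager;

    // Hypothetical handler, for illustration only. Before this fix,
    // client.getClientId() returned null on connections authenticated with
    // the new protocol, so per-app checks like this could not work.
    public class AppCheckingRpcHandler extends RpcHandler {
      @Override
      public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
        // The client id is the app id negotiated during the auth handshake.
        String appId = client.getClientId();
        if (appId == null) {
          callback.onFailure(new SecurityException("Client is not authenticated."));
          return;
        }
        callback.onSuccess(message);  // echo back, now attributable to an app
      }

      @Override
      public StreamManager getStreamManager() {
        return new OneForOneStreamManager();
      }
    }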

I also refactored the SASL test that checks for this so it also checks the
new protocol (test failed before the fix, passes now).
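
The shared test body is parameterized over bootstrap factories, so the same assertions run against both SASL and the new auth engine. In outline (the full version is in AppIsolationSuite below):

    private void testAppIsolation(
        Supplier<TransportServerBootstrap> serverBootstrap,
        Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
      // ... shared setup and app-isolation assertions ...
    }

    @Test
    public void testSaslAppIsolation() throws Exception {
      testAppIsolation(
        () -> new SaslServerBootstrap(conf, secretKeyHolder),
        appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
    }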

Closes #26760 from vanzin/SPARK-30129.

Authored-by: Marcelo Vanzin <vanzin@cloudera.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
commit c5f312a6ac (parent 29e09a83b7)
Marcelo Vanzin, 2019-12-04 17:11:50 -08:00; committed by Dongjoon Hyun
4 changed files with 186 additions and 106 deletions


@@ -78,6 +78,7 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
     try {
       doSparkAuth(client, channel);
+      client.setClientId(appId);
     } catch (GeneralSecurityException | IOException e) {
       throw Throwables.propagate(e);
     } catch (RuntimeException e) {


@@ -125,6 +125,7 @@ class AuthRpcHandler extends RpcHandler {
       response.encode(responseData);
       callback.onSuccess(responseData.nioBuffer());
       engine.sessionCipher().addToChannel(channel);
+      client.setClientId(challenge.appId);
     } catch (Exception e) {
       // This is a fatal error: authentication has failed. Close the channel explicitly.
       LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());


@@ -21,8 +21,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.atomic.AtomicReference;
 
 import org.junit.After;
 import org.junit.AfterClass;
@@ -34,8 +32,6 @@ import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.TestUtils;
 import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientFactory;
@@ -44,15 +40,6 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
-import org.apache.spark.network.shuffle.BlockFetchingListener;
-import org.apache.spark.network.shuffle.ExternalBlockHandler;
-import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver;
-import org.apache.spark.network.shuffle.OneForOneBlockFetcher;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
-import org.apache.spark.network.shuffle.protocol.OpenBlocks;
-import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
-import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
@@ -165,93 +152,6 @@ public class SaslIntegrationSuite {
     }
   }
 
-  /**
-   * This test is not actually testing SASL behavior, but testing that the shuffle service
-   * performs correct authorization checks based on the SASL authentication data.
-   */
-  @Test
-  public void testAppIsolation() throws Exception {
-    // Start a new server with the correct RPC handler to serve block data.
-    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
-    ExternalBlockHandler blockHandler = new ExternalBlockHandler(
-      new OneForOneStreamManager(), blockResolver);
-    TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder);
-
-    try (
-      TransportContext blockServerContext = new TransportContext(conf, blockHandler);
-      TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
-      // Create a client, and make a request to fetch blocks from a different app.
-      TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
-        Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
-      TransportClient client1 = clientFactory1.createClient(
-        TestUtils.getLocalHost(), blockServer.getPort())) {
-
-      AtomicReference<Throwable> exception = new AtomicReference<>();
-
-      CountDownLatch blockFetchLatch = new CountDownLatch(1);
-      BlockFetchingListener listener = new BlockFetchingListener() {
-        @Override
-        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
-          blockFetchLatch.countDown();
-        }
-        @Override
-        public void onBlockFetchFailure(String blockId, Throwable t) {
-          exception.set(t);
-          blockFetchLatch.countDown();
-        }
-      };
-
-      String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
-      OneForOneBlockFetcher fetcher =
-        new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
-      fetcher.start();
-      blockFetchLatch.await();
-      checkSecurityException(exception.get());
-
-      // Register an executor so that the next steps work.
-      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
-        new String[] { System.getProperty("java.io.tmpdir") }, 1,
-        "org.apache.spark.shuffle.sort.SortShuffleManager");
-      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
-      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
-
-      // Make a successful request to fetch blocks, which creates a new stream. But do not
-      // actually fetch any blocks, to keep the stream open.
-      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
-      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
-      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
-      long streamId = stream.streamId;
-
-      try (
-        // Create a second client, authenticated with a different app ID, and try to read from
-        // the stream created for the previous app.
-        TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
-          Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder)));
-        TransportClient client2 = clientFactory2.createClient(
-          TestUtils.getLocalHost(), blockServer.getPort())
-      ) {
-        CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
-        ChunkReceivedCallback callback = new ChunkReceivedCallback() {
-          @Override
-          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
-            chunkReceivedLatch.countDown();
-          }
-          @Override
-          public void onFailure(int chunkIndex, Throwable t) {
-            exception.set(t);
-            chunkReceivedLatch.countDown();
-          }
-        };
-
-        exception.set(null);
-        client2.fetchChunk(streamId, 0, callback);
-        chunkReceivedLatch.await();
-        checkSecurityException(exception.get());
-      }
-    }
-  }
-
   /** RPC handler which simply responds with the message it received. */
   public static class TestRpcHandler extends RpcHandler {
     @Override
@@ -264,10 +164,4 @@ public class SaslIntegrationSuite {
       return new OneForOneStreamManager();
     }
   }
-
-  private static void checkSecurityException(Throwable t) {
-    assertNotNull("No exception was caught.", t);
-    assertTrue("Expected SecurityException.",
-      t.getMessage().contains(SecurityException.class.getName()));
-  }
 }


@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.TestUtils;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.ChunkReceivedCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.crypto.AuthClientBootstrap;
+import org.apache.spark.network.crypto.AuthServerBootstrap;
+import org.apache.spark.network.sasl.SaslClientBootstrap;
+import org.apache.spark.network.sasl.SaslServerBootstrap;
+import org.apache.spark.network.sasl.SecretKeyHolder;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.TransportServerBootstrap;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class AppIsolationSuite {
+
+  // Use a long timeout to account for slow / overloaded build machines. In the normal case,
+  // tests should finish way before the timeout expires.
+  private static final long TIMEOUT_MS = 10_000;
+
+  private static SecretKeyHolder secretKeyHolder;
+  private static TransportConf conf;
+
+  @BeforeClass
+  public static void beforeAll() {
+    Map<String, String> confMap = new HashMap<>();
+    confMap.put("spark.network.crypto.enabled", "true");
+    confMap.put("spark.network.crypto.saslFallback", "false");
+    conf = new TransportConf("shuffle", new MapConfigProvider(confMap));
+
+    secretKeyHolder = mock(SecretKeyHolder.class);
+    when(secretKeyHolder.getSaslUser(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSecretKey(eq("app-1"))).thenReturn("app-1");
+    when(secretKeyHolder.getSaslUser(eq("app-2"))).thenReturn("app-2");
+    when(secretKeyHolder.getSecretKey(eq("app-2"))).thenReturn("app-2");
+  }
+
+  @Test
+  public void testSaslAppIsolation() throws Exception {
+    testAppIsolation(
+      () -> new SaslServerBootstrap(conf, secretKeyHolder),
+      appId -> new SaslClientBootstrap(conf, appId, secretKeyHolder));
+  }
+
+  @Test
+  public void testAuthEngineAppIsolation() throws Exception {
+    testAppIsolation(
+      () -> new AuthServerBootstrap(conf, secretKeyHolder),
+      appId -> new AuthClientBootstrap(conf, appId, secretKeyHolder));
+  }
+
+  private void testAppIsolation(
+      Supplier<TransportServerBootstrap> serverBootstrap,
+      Function<String, TransportClientBootstrap> clientBootstrapFactory) throws Exception {
+    // Start a new server with the correct RPC handler to serve block data.
+    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+    ExternalBlockHandler blockHandler = new ExternalBlockHandler(
+      new OneForOneStreamManager(), blockResolver);
+    TransportServerBootstrap bootstrap = serverBootstrap.get();
+
+    try (
+      TransportContext blockServerContext = new TransportContext(conf, blockHandler);
+      TransportServer blockServer = blockServerContext.createServer(Arrays.asList(bootstrap));
+      // Create a client, and make a request to fetch blocks from a different app.
+      TransportClientFactory clientFactory1 = blockServerContext.createClientFactory(
+        Arrays.asList(clientBootstrapFactory.apply("app-1")));
+      TransportClient client1 = clientFactory1.createClient(
+        TestUtils.getLocalHost(), blockServer.getPort())) {
+
+      AtomicReference<Throwable> exception = new AtomicReference<>();
+
+      CountDownLatch blockFetchLatch = new CountDownLatch(1);
+      BlockFetchingListener listener = new BlockFetchingListener() {
+        @Override
+        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+          blockFetchLatch.countDown();
+        }
+        @Override
+        public void onBlockFetchFailure(String blockId, Throwable t) {
+          exception.set(t);
+          blockFetchLatch.countDown();
+        }
+      };
+
+      String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
+      OneForOneBlockFetcher fetcher =
+        new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
+      fetcher.start();
+      blockFetchLatch.await();
+      checkSecurityException(exception.get());
+
+      // Register an executor so that the next steps work.
+      ExecutorShuffleInfo executorInfo = new ExecutorShuffleInfo(
+        new String[] { System.getProperty("java.io.tmpdir") }, 1,
+        "org.apache.spark.shuffle.sort.SortShuffleManager");
+      RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
+      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
+
+      // Make a successful request to fetch blocks, which creates a new stream. But do not
+      // actually fetch any blocks, to keep the stream open.
+      OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
+      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
+      long streamId = stream.streamId;
+
+      try (
+        // Create a second client, authenticated with a different app ID, and try to read from
+        // the stream created for the previous app.
+        TransportClientFactory clientFactory2 = blockServerContext.createClientFactory(
+          Arrays.asList(clientBootstrapFactory.apply("app-2")));
+        TransportClient client2 = clientFactory2.createClient(
+          TestUtils.getLocalHost(), blockServer.getPort())
+      ) {
+        CountDownLatch chunkReceivedLatch = new CountDownLatch(1);
+        ChunkReceivedCallback callback = new ChunkReceivedCallback() {
+          @Override
+          public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
+            chunkReceivedLatch.countDown();
+          }
+          @Override
+          public void onFailure(int chunkIndex, Throwable t) {
+            exception.set(t);
+            chunkReceivedLatch.countDown();
+          }
+        };
+
+        exception.set(null);
+        client2.fetchChunk(streamId, 0, callback);
+        chunkReceivedLatch.await();
+        checkSecurityException(exception.get());
+      }
+    }
+  }
+
+  private static void checkSecurityException(Throwable t) {
+    assertNotNull("No exception was caught.", t);
+    assertTrue("Expected SecurityException.",
+      t.getMessage().contains(SecurityException.class.getName()));
+  }
+}