[SPARK-27275][CORE] Fix potential corruption in EncryptedMessage.transferTo

## What changes were proposed in this pull request?

Right now there are several issues in `EncryptedMessage.transferTo`:

- When the underlying buffer has more than `1024 * 32` bytes (this should be rare but it could happen in error messages that send over the wire), it may just send a partial message as `EncryptedMessage.count` becomes less than `transferred`. This will cause the client hang forever (or timeout) as it will wait until receiving expected length of bytes, or weird errors (such as corruption or silent correctness issue) if the channel is reused by other messages.
- When the underlying buffer is full, it's still trying to write out bytes in a busy loop.

This PR fixes  the issues in `EncryptedMessage.transferTo` and also makes it follow the contract of `FileRegion`:

- `count` should be a fixed value which is just the length of the whole message.
- It should be non-blocking. When the underlying socket is not ready to write, it should give up and give control back.
- `transferTo` should return the length of written bytes.

## How was this patch tested?

The new added tests.

Closes #24211 from zsxwing/fix-enc.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Shixiong Zhu 2019-03-26 15:48:29 -07:00 committed by Wenchen Fan
parent 69035684d4
commit 5624bfbcfe
3 changed files with 166 additions and 18 deletions

View file

@ -44,7 +44,8 @@ public class TransportCipher {
@VisibleForTesting
static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
private static final int STREAM_BUFFER_SIZE = 1024 * 32;
@VisibleForTesting
static final int STREAM_BUFFER_SIZE = 1024 * 32;
private final Properties conf;
private final String cipher;
@ -84,7 +85,8 @@ public class TransportCipher {
return outIv;
}
private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
@VisibleForTesting
CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException {
return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv));
}
@ -104,7 +106,8 @@ public class TransportCipher {
.addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
}
private static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
@VisibleForTesting
static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
private final ByteArrayWritableChannel byteChannel;
private final CryptoOutputStream cos;
private boolean isCipherValid;
@ -118,7 +121,12 @@ public class TransportCipher {
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
throws Exception {
ctx.write(new EncryptedMessage(this, cos, msg, byteChannel), promise);
ctx.write(createEncryptedMessage(msg), promise);
}
@VisibleForTesting
EncryptedMessage createEncryptedMessage(Object msg) {
return new EncryptedMessage(this, cos, msg, byteChannel);
}
@Override
@ -190,12 +198,14 @@ public class TransportCipher {
}
}
private static class EncryptedMessage extends AbstractFileRegion {
@VisibleForTesting
static class EncryptedMessage extends AbstractFileRegion {
private final boolean isByteBuf;
private final ByteBuf buf;
private final FileRegion region;
private final CryptoOutputStream cos;
private final EncryptionHandler handler;
private final long count;
private long transferred;
// Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has
@ -221,11 +231,12 @@ public class TransportCipher {
this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
this.cos = cos;
this.byteEncChannel = ch;
this.count = isByteBuf ? buf.readableBytes() : region.count();
}
@Override
public long count() {
return isByteBuf ? buf.readableBytes() : region.count();
return count;
}
@Override
@ -277,22 +288,38 @@ public class TransportCipher {
public long transferTo(WritableByteChannel target, long position) throws IOException {
Preconditions.checkArgument(position == transferred(), "Invalid position.");
if (transferred == count) {
return 0;
}
long totalBytesWritten = 0L;
do {
if (currentEncrypted == null) {
encryptMore();
}
int bytesWritten = currentEncrypted.remaining();
target.write(currentEncrypted);
bytesWritten -= currentEncrypted.remaining();
transferred += bytesWritten;
if (!currentEncrypted.hasRemaining()) {
long remaining = currentEncrypted.remaining();
if (remaining == 0) {
// Just for safety to avoid endless loop. It usually won't happen, but since the
// underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for
// safety.
currentEncrypted = null;
byteEncChannel.reset();
return totalBytesWritten;
}
} while (transferred < count());
return transferred;
long bytesWritten = target.write(currentEncrypted);
totalBytesWritten += bytesWritten;
transferred += bytesWritten;
if (bytesWritten < remaining) {
// break as the underlying buffer in "target" is full
break;
}
currentEncrypted = null;
byteEncChannel.reset();
} while (transferred < count);
return totalBytesWritten;
}
private void encryptMore() throws IOException {

View file

@ -17,16 +17,27 @@
package org.apache.spark.network.crypto;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.Arrays;
import java.util.Map;
import java.security.InvalidKeyException;
import java.util.Random;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.collect.ImmutableMap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;
@ -121,4 +132,77 @@ public class AuthEngineSuite {
}
}
@Test
public void testEncryptedMessage() throws Exception {
AuthEngine client = new AuthEngine("appId", "secret", conf);
AuthEngine server = new AuthEngine("appId", "secret", conf);
try {
ClientChallenge clientChallenge = client.challenge();
ServerResponse serverResponse = server.respond(clientChallenge);
client.validate(serverResponse);
TransportCipher cipher = server.sessionCipher();
TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1];
new Random().nextBytes(data);
ByteBuf buf = Unpooled.wrappedBuffer(data);
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf);
while (emsg.transfered() < emsg.count()) {
emsg.transferTo(channel, emsg.transfered());
}
assertEquals(data.length, channel.length());
} finally {
client.close();
server.close();
}
}
@Test
public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception {
AuthEngine client = new AuthEngine("appId", "secret", conf);
AuthEngine server = new AuthEngine("appId", "secret", conf);
try {
ClientChallenge clientChallenge = client.challenge();
ServerResponse serverResponse = server.respond(clientChallenge);
client.validate(serverResponse);
TransportCipher cipher = server.sessionCipher();
TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher);
int testDataLength = 4;
FileRegion region = mock(FileRegion.class);
when(region.count()).thenReturn((long) testDataLength);
// Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one.
when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
private boolean firstTime = true;
@Override
public Long answer(InvocationOnMock invocationOnMock) throws Throwable {
if (firstTime) {
firstTime = false;
return 0L;
} else {
WritableByteChannel channel = invocationOnMock.getArgument(0);
channel.write(ByteBuffer.wrap(new byte[testDataLength]));
return (long) testDataLength;
}
}
});
TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region);
ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength);
// "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes.
assertEquals(0L, emsg.transferTo(channel, emsg.transfered()));
assertEquals(testDataLength, emsg.transferTo(channel, emsg.transfered()));
assertEquals(emsg.transfered(), emsg.count());
assertEquals(4, channel.length());
} finally {
client.close();
server.close();
}
}
}

View file

@ -124,6 +124,42 @@ public class AuthIntegrationSuite {
}
}
@Test
public void testLargeMessageEncryption() throws Exception {
// Use a big length to create a message that cannot be put into the encryption buffer completely
final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE;
ctx = new AuthTestCtx(new RpcHandler() {
@Override
public void receive(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
char[] longMessage = new char[testErrorMessageLength];
Arrays.fill(longMessage, 'D');
callback.onFailure(new RuntimeException(new String(longMessage)));
}
@Override
public StreamManager getStreamManager() {
return null;
}
});
ctx.createServer("secret");
ctx.createClient("secret");
try {
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
fail("Should have failed unencrypted RPC.");
} catch (Exception e) {
assertTrue(ctx.authRpcHandler.doDelegate);
assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
// Verify we receive the complete error message
int messageStart = e.getMessage().indexOf("DDDDD");
int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
assertEquals(testErrorMessageLength, messageEnd - messageStart);
}
}
private class AuthTestCtx {
private final String appId = "testAppId";
@ -136,10 +172,7 @@ public class AuthIntegrationSuite {
volatile AuthRpcHandler authRpcHandler;
AuthTestCtx() throws Exception {
Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
RpcHandler rpcHandler = new RpcHandler() {
this(new RpcHandler() {
@Override
public void receive(
TransportClient client,
@ -153,8 +186,12 @@ public class AuthIntegrationSuite {
public StreamManager getStreamManager() {
return null;
}
};
});
}
AuthTestCtx(RpcHandler rpcHandler) throws Exception {
Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true");
this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
this.ctx = new TransportContext(conf, rpcHandler);
}