diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 8b8f989284..45fee541a4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -77,16 +77,16 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { + String errorMessage = "Error in reading " + this; try { if (channel != null) { long size = channel.size(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; } } catch (IOException ignored) { // ignore } - throw new IOException("Error in opening " + this, e); + throw new IOException(errorMessage, e); } finally { JavaUtils.closeQuietly(channel); } @@ -95,26 +95,24 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { @Override public InputStream createInputStream() throws IOException { FileInputStream is = null; + boolean shouldClose = true; try { is = new FileInputStream(file); ByteStreams.skipFully(is, offset); - return new LimitedInputStream(is, length); + InputStream r = new LimitedInputStream(is, length); + shouldClose = false; + return r; } catch (IOException e) { - try { - if (is != null) { - long size = file.length(); - throw new IOException("Error in reading " + this + " (actual file length " + size + ")", - e); - } - } catch (IOException ignored) { - // ignore - } finally { + String errorMessage = "Error in reading " + this; + if (is != null) { + long size = file.length(); + errorMessage = "Error in reading " + this + " (actual file length " + size + ")"; + } + throw new IOException(errorMessage, e); + } finally { + if (shouldClose) { JavaUtils.closeQuietly(is); } - throw new IOException("Error in opening " + this, e); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(is); - throw e; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index d95ed22912..9c85ab2f5f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -70,11 +70,14 @@ public class TransportServer implements Closeable { this.appRpcHandler = appRpcHandler; this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); + boolean shouldClose = true; try { init(hostToBind, portToBind); - } catch (RuntimeException e) { - JavaUtils.closeQuietly(this); - throw e; + shouldClose = false; + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(this); + } } } diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala index d15e7937b0..ea38ccb289 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -42,43 +42,59 @@ private[spark] class SocketAuthHelper(conf: SparkConf) { * Read the auth secret from the socket and compare to the expected value. Write the reply back * to the socket. * - * If authentication fails, this method will close the socket. + * If authentication fails or error is thrown, this method will close the socket. * * @param s The client socket. * @throws IllegalArgumentException If authentication fails. */ def authClient(s: Socket): Unit = { - // Set the socket timeout while checking the auth secret. Reset it before returning. - val currentTimeout = s.getSoTimeout() + var shouldClose = true try { - s.setSoTimeout(10000) - val clientSecret = readUtf8(s) - if (secret == clientSecret) { - writeUtf8("ok", s) - } else { - writeUtf8("err", s) - JavaUtils.closeQuietly(s) + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + shouldClose = false + } else { + writeUtf8("err", s) + throw new IllegalArgumentException("Authentication failed.") + } + } finally { + s.setSoTimeout(currentTimeout) } } finally { - s.setSoTimeout(currentTimeout) + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } } /** * Authenticate with a server by writing the auth secret and checking the server's reply. * - * If authentication fails, this method will close the socket. + * If authentication fails or error is thrown, this method will close the socket. * * @param s The socket connected to the server. * @throws IllegalArgumentException If authentication fails. */ def authToServer(s: Socket): Unit = { - writeUtf8(secret, s) + var shouldClose = true + try { + writeUtf8(secret, s) - val reply = readUtf8(s) - if (reply != "ok") { - JavaUtils.closeQuietly(s) - throw new IllegalArgumentException("Authentication failed.") + val reply = readUtf8(s) + if (reply != "ok") { + throw new IllegalArgumentException("Authentication failed.") + } else { + shouldClose = false + } + } finally { + if (shouldClose) { + JavaUtils.closeQuietly(s) + } } }