[SPARK-13308] ManagedBuffers passed to OneToOneStreamManager need to be freed in non-error cases

ManagedBuffers that are passed to `OneToOneStreamManager.registerStream` need to be freed by the manager once it's done using them. However, the current code only frees them in certain error-cases and not during typical operation. This isn't a major problem today, but it will cause memory leaks after we implement better locking / pinning in the BlockManager (see #10705).

This patch modifies the relevant network code so that the ManagedBuffers are freed as soon as the messages containing them are processed by the lower-level Netty message sending code.

/cc zsxwing for review.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11193 from JoshRosen/add-missing-release-calls-in-network-layer.
This commit is contained in:
Josh Rosen 2016-02-16 12:06:30 -08:00 committed by Shixiong Zhu
parent c7d00a24da
commit 5f37aad48c
7 changed files with 119 additions and 9 deletions

View file

@ -65,7 +65,11 @@ public abstract class ManagedBuffer {
public abstract ManagedBuffer release();
/**
* Convert the buffer into an Netty object, used to write the data out.
* Convert the buffer into an Netty object, used to write the data out. The return value is either
* a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}.
*
* If this method returns a ByteBuf, then that buffer's reference count will be incremented and
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;
}

View file

@ -64,7 +64,7 @@ public final class NettyManagedBuffer extends ManagedBuffer {
@Override
public Object convertToNetty() throws IOException {
return buf.duplicate();
return buf.duplicate().retain();
}
@Override

View file

@ -54,6 +54,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
body = in.body().convertToNetty();
isBodyInFrame = in.isBodyInFrame();
} catch (Exception e) {
in.body().release();
if (in instanceof AbstractResponseMessage) {
AbstractResponseMessage resp = (AbstractResponseMessage) in;
// Re-encode this message as a failure response.
@ -80,8 +81,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
in.encode(header);
assert header.writableBytes() == 0;
if (body != null && bodyLength > 0) {
out.add(new MessageWithHeader(header, body, bodyLength));
if (body != null) {
// We transfer ownership of the reference on in.body() to MessageWithHeader.
// This reference will be freed when MessageWithHeader.deallocate() is called.
out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}

View file

@ -19,6 +19,7 @@ package org.apache.spark.network.protocol;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import javax.annotation.Nullable;
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
@ -26,6 +27,8 @@ import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
import org.apache.spark.network.buffer.ManagedBuffer;
/**
* A wrapper message that holds two separate pieces (a header and a body).
*
@ -33,15 +36,35 @@ import io.netty.util.ReferenceCountUtil;
*/
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
@Nullable private final ManagedBuffer managedBuffer;
private final ByteBuf header;
private final int headerLength;
private final Object body;
private final long bodyLength;
private long totalBytesTransferred;
MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
/**
* Construct a new MessageWithHeader.
*
* @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
* be passed in so that the buffer can be freed when this message is
* deallocated. Ownership of the caller's reference to this buffer is
* transferred to this class, so if the caller wants to continue to use the
* ManagedBuffer in other messages then they will need to call retain() on
* it before passing it to this constructor. This may be null if and only if
* `body` is a {@link FileRegion}.
* @param header the message header.
* @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
* @param bodyLength the length of the message body, in bytes.
*/
MessageWithHeader(
@Nullable ManagedBuffer managedBuffer,
ByteBuf header,
Object body,
long bodyLength) {
Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
"Body must be a ByteBuf or a FileRegion.");
this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
this.body = body;
@ -99,6 +122,9 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
protected void deallocate() {
header.release();
ReferenceCountUtil.release(body);
if (managedBuffer != null) {
managedBuffer.release();
}
}
private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {

View file

@ -20,7 +20,6 @@ package org.apache.spark.network.server;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

View file

@ -26,9 +26,13 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import org.junit.Test;
import org.mockito.Mockito;
import static org.junit.Assert.*;
import org.apache.spark.network.TestManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.util.ByteArrayWritableChannel;
public class MessageWithHeaderSuite {
@ -46,20 +50,43 @@ public class MessageWithHeaderSuite {
@Test
public void testByteBufBody() throws Exception {
ByteBuf header = Unpooled.copyLong(42);
ByteBuf body = Unpooled.copyLong(84);
MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
assertEquals(1, header.refCnt());
assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
Object body = managedBuf.convertToNetty();
assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
assertEquals(1, header.refCnt());
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
ByteBuf result = doWrite(msg, 1);
assertEquals(msg.count(), result.readableBytes());
assertEquals(42, result.readLong());
assertEquals(84, result.readLong());
assert(msg.release());
assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
assertEquals(0, header.refCnt());
}
@Test
public void testDeallocateReleasesManagedBuffer() throws Exception {
ByteBuf header = Unpooled.copyLong(42);
ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
assertEquals(2, body.refCnt());
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
assert(msg.release());
Mockito.verify(managedBuf, Mockito.times(1)).release();
assertEquals(0, body.refCnt());
}
private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
ByteBuf header = Unpooled.copyLong(42);
int headerLength = header.readableBytes();
TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
assertEquals(headerLength + region.count(), result.readableBytes());
@ -67,6 +94,7 @@ public class MessageWithHeaderSuite {
for (long i = 0; i < 8; i++) {
assertEquals(i, result.readLong());
}
assert(msg.release());
}
private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {

View file

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.server;
import java.util.ArrayList;
import java.util.List;
import io.netty.channel.Channel;
import org.junit.Test;
import org.mockito.Mockito;
import org.apache.spark.network.TestManagedBuffer;
import org.apache.spark.network.buffer.ManagedBuffer;
public class OneForOneStreamManagerSuite {
@Test
public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
OneForOneStreamManager manager = new OneForOneStreamManager();
List<ManagedBuffer> buffers = new ArrayList<>();
TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
buffers.add(buffer1);
buffers.add(buffer2);
long streamId = manager.registerStream("appId", buffers.iterator());
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerChannel(dummyChannel, streamId);
manager.connectionTerminated(dummyChannel);
Mockito.verify(buffer1, Mockito.times(1)).release();
Mockito.verify(buffer2, Mockito.times(1)).release();
}
}