[SPARK-13308] ManagedBuffers passed to OneToOneStreamManager need to be freed in non-error cases
ManagedBuffers that are passed to `OneToOneStreamManager.registerStream` need to be freed by the manager once it's done using them. However, the current code only frees them in certain error-cases and not during typical operation. This isn't a major problem today, but it will cause memory leaks after we implement better locking / pinning in the BlockManager (see #10705). This patch modifies the relevant network code so that the ManagedBuffers are freed as soon as the messages containing them are processed by the lower-level Netty message sending code. /cc zsxwing for review. Author: Josh Rosen <joshrosen@databricks.com> Closes #11193 from JoshRosen/add-missing-release-calls-in-network-layer.
This commit is contained in:
parent
c7d00a24da
commit
5f37aad48c
|
@ -65,7 +65,11 @@ public abstract class ManagedBuffer {
|
|||
public abstract ManagedBuffer release();
|
||||
|
||||
/**
|
||||
* Convert the buffer into an Netty object, used to write the data out.
|
||||
* Convert the buffer into an Netty object, used to write the data out. The return value is either
|
||||
* a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}.
|
||||
*
|
||||
* If this method returns a ByteBuf, then that buffer's reference count will be incremented and
|
||||
* the caller will be responsible for releasing this new reference.
|
||||
*/
|
||||
public abstract Object convertToNetty() throws IOException;
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ public final class NettyManagedBuffer extends ManagedBuffer {
|
|||
|
||||
@Override
|
||||
public Object convertToNetty() throws IOException {
|
||||
return buf.duplicate();
|
||||
return buf.duplicate().retain();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -54,6 +54,7 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
|
|||
body = in.body().convertToNetty();
|
||||
isBodyInFrame = in.isBodyInFrame();
|
||||
} catch (Exception e) {
|
||||
in.body().release();
|
||||
if (in instanceof AbstractResponseMessage) {
|
||||
AbstractResponseMessage resp = (AbstractResponseMessage) in;
|
||||
// Re-encode this message as a failure response.
|
||||
|
@ -80,8 +81,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
|
|||
in.encode(header);
|
||||
assert header.writableBytes() == 0;
|
||||
|
||||
if (body != null && bodyLength > 0) {
|
||||
out.add(new MessageWithHeader(header, body, bodyLength));
|
||||
if (body != null) {
|
||||
// We transfer ownership of the reference on in.body() to MessageWithHeader.
|
||||
// This reference will be freed when MessageWithHeader.deallocate() is called.
|
||||
out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
|
||||
} else {
|
||||
out.add(header);
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.network.protocol;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.nio.channels.WritableByteChannel;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
|
@ -26,6 +27,8 @@ import io.netty.channel.FileRegion;
|
|||
import io.netty.util.AbstractReferenceCounted;
|
||||
import io.netty.util.ReferenceCountUtil;
|
||||
|
||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||
|
||||
/**
|
||||
* A wrapper message that holds two separate pieces (a header and a body).
|
||||
*
|
||||
|
@ -33,15 +36,35 @@ import io.netty.util.ReferenceCountUtil;
|
|||
*/
|
||||
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
|
||||
|
||||
@Nullable private final ManagedBuffer managedBuffer;
|
||||
private final ByteBuf header;
|
||||
private final int headerLength;
|
||||
private final Object body;
|
||||
private final long bodyLength;
|
||||
private long totalBytesTransferred;
|
||||
|
||||
MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
|
||||
/**
|
||||
* Construct a new MessageWithHeader.
|
||||
*
|
||||
* @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
|
||||
* be passed in so that the buffer can be freed when this message is
|
||||
* deallocated. Ownership of the caller's reference to this buffer is
|
||||
* transferred to this class, so if the caller wants to continue to use the
|
||||
* ManagedBuffer in other messages then they will need to call retain() on
|
||||
* it before passing it to this constructor. This may be null if and only if
|
||||
* `body` is a {@link FileRegion}.
|
||||
* @param header the message header.
|
||||
* @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
|
||||
* @param bodyLength the length of the message body, in bytes.
|
||||
*/
|
||||
MessageWithHeader(
|
||||
@Nullable ManagedBuffer managedBuffer,
|
||||
ByteBuf header,
|
||||
Object body,
|
||||
long bodyLength) {
|
||||
Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
|
||||
"Body must be a ByteBuf or a FileRegion.");
|
||||
this.managedBuffer = managedBuffer;
|
||||
this.header = header;
|
||||
this.headerLength = header.readableBytes();
|
||||
this.body = body;
|
||||
|
@ -99,6 +122,9 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
|
|||
protected void deallocate() {
|
||||
header.release();
|
||||
ReferenceCountUtil.release(body);
|
||||
if (managedBuffer != null) {
|
||||
managedBuffer.release();
|
||||
}
|
||||
}
|
||||
|
||||
private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.network.server;
|
|||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
|
|
|
@ -26,9 +26,13 @@ import io.netty.buffer.Unpooled;
|
|||
import io.netty.channel.FileRegion;
|
||||
import io.netty.util.AbstractReferenceCounted;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.apache.spark.network.TestManagedBuffer;
|
||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||
import org.apache.spark.network.buffer.NettyManagedBuffer;
|
||||
import org.apache.spark.network.util.ByteArrayWritableChannel;
|
||||
|
||||
public class MessageWithHeaderSuite {
|
||||
|
@ -46,20 +50,43 @@ public class MessageWithHeaderSuite {
|
|||
@Test
|
||||
public void testByteBufBody() throws Exception {
|
||||
ByteBuf header = Unpooled.copyLong(42);
|
||||
ByteBuf body = Unpooled.copyLong(84);
|
||||
MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
|
||||
ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
|
||||
assertEquals(1, header.refCnt());
|
||||
assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
|
||||
ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
|
||||
|
||||
Object body = managedBuf.convertToNetty();
|
||||
assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
|
||||
assertEquals(1, header.refCnt());
|
||||
|
||||
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
|
||||
ByteBuf result = doWrite(msg, 1);
|
||||
assertEquals(msg.count(), result.readableBytes());
|
||||
assertEquals(42, result.readLong());
|
||||
assertEquals(84, result.readLong());
|
||||
|
||||
assert(msg.release());
|
||||
assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
|
||||
assertEquals(0, header.refCnt());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeallocateReleasesManagedBuffer() throws Exception {
|
||||
ByteBuf header = Unpooled.copyLong(42);
|
||||
ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
|
||||
ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
|
||||
assertEquals(2, body.refCnt());
|
||||
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
|
||||
assert(msg.release());
|
||||
Mockito.verify(managedBuf, Mockito.times(1)).release();
|
||||
assertEquals(0, body.refCnt());
|
||||
}
|
||||
|
||||
private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
|
||||
ByteBuf header = Unpooled.copyLong(42);
|
||||
int headerLength = header.readableBytes();
|
||||
TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
|
||||
MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
|
||||
MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
|
||||
|
||||
ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
|
||||
assertEquals(headerLength + region.count(), result.readableBytes());
|
||||
|
@ -67,6 +94,7 @@ public class MessageWithHeaderSuite {
|
|||
for (long i = 0; i < 8; i++) {
|
||||
assertEquals(i, result.readLong());
|
||||
}
|
||||
assert(msg.release());
|
||||
}
|
||||
|
||||
private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.network.server;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.apache.spark.network.TestManagedBuffer;
|
||||
import org.apache.spark.network.buffer.ManagedBuffer;
|
||||
|
||||
public class OneForOneStreamManagerSuite {
|
||||
|
||||
@Test
|
||||
public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
|
||||
OneForOneStreamManager manager = new OneForOneStreamManager();
|
||||
List<ManagedBuffer> buffers = new ArrayList<>();
|
||||
TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
|
||||
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
|
||||
buffers.add(buffer1);
|
||||
buffers.add(buffer2);
|
||||
long streamId = manager.registerStream("appId", buffers.iterator());
|
||||
|
||||
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
|
||||
manager.registerChannel(dummyChannel, streamId);
|
||||
|
||||
manager.connectionTerminated(dummyChannel);
|
||||
|
||||
Mockito.verify(buffer1, Mockito.times(1)).release();
|
||||
Mockito.verify(buffer2, Mockito.times(1)).release();
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue