[SPARK-17839][CORE] Use NIO's direct buffer instead of BufferedInputStream to avoid an additional copy from the OS buffer cache to the user buffer

## What changes were proposed in this pull request?

Currently we use BufferedInputStream to read the shuffle file, which copies the file content from the OS buffer cache to a user buffer. This adds additional latency when reading the spill files. This change uses Java NIO's direct buffer to read the spill files instead; for certain pipelines that spill a significant amount of data, we see up to a 7% speedup for the entire pipeline.
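To illustrate the difference, here is a minimal sketch, not part of this patch (the file path and buffer size are made up): a `BufferedInputStream` read lands in native memory first and is then copied into the Java heap array backing the stream's buffer, whereas a `FileChannel` read into a direct `ByteBuffer` lets the kernel fill native memory that can be consumed in place.

```java
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

public class ReadPathSketch {
  private static final int BUF_SIZE = 8192;  // hypothetical buffer size

  public static void main(String[] args) throws IOException {
    String path = "/tmp/spill.data";  // hypothetical spill file

    // Heap-buffered path: bytes are copied from the OS buffer cache into
    // native memory and then again into the heap array.
    try (BufferedInputStream bis =
             new BufferedInputStream(new FileInputStream(path), BUF_SIZE)) {
      byte[] heapBuf = new byte[BUF_SIZE];
      while (bis.read(heapBuf) != -1) {
        // consume heapBuf ...
      }
    }

    // Direct-buffer path: the kernel fills native memory directly and the
    // bytes are consumed from the direct buffer without an extra heap copy.
    ByteBuffer direct = ByteBuffer.allocateDirect(BUF_SIZE);
    try (FileChannel channel =
             FileChannel.open(Paths.get(path), StandardOpenOption.READ)) {
      while (channel.read(direct) != -1) {
        direct.flip();
        while (direct.hasRemaining()) {
          byte b = direct.get();  // consume in place
        }
        direct.clear();
      }
    }
  }
}
```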

## How was this patch tested?
Tested by running the job on the cluster and observing up to a 7% speedup.

Author: Sital Kedia <skedia@fb.com>

Closes #15408 from sitalkedia/skedia/nio_spill_read.
Authored by Sital Kedia on 2016-10-17 11:03:04 -07:00; committed by Shixiong Zhu
parent e3bf37fa3a
commit c7ac027d5f
5 changed files with 279 additions and 4 deletions


@@ -0,0 +1,137 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.io;
import org.apache.spark.storage.StorageUtils;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.StandardOpenOption;
/**
 * {@link InputStream} implementation which uses a direct buffer
 * to read a file, to avoid an extra copy of data between Java and
 * native memory which happens when using {@link java.io.BufferedInputStream}.
 * Unfortunately, this is not something already available in the JDK;
 * {@link sun.nio.ch.ChannelInputStream} supports reading a file using NIO,
 * but does not support buffering.
 */
public final class NioBufferedFileInputStream extends InputStream {
private static final int DEFAULT_BUFFER_SIZE_BYTES = 8192;
private final ByteBuffer byteBuffer;
private final FileChannel fileChannel;
public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException {
byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes);
fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
byteBuffer.flip();
}
public NioBufferedFileInputStream(File file) throws IOException {
this(file, DEFAULT_BUFFER_SIZE_BYTES);
}
/**
* Checks whether data is left to be read from the input stream.
* @return true if data is left, false otherwise
* @throws IOException
*/
private boolean refill() throws IOException {
if (!byteBuffer.hasRemaining()) {
byteBuffer.clear();
int nRead = 0;
while (nRead == 0) {
nRead = fileChannel.read(byteBuffer);
}
if (nRead < 0) {
return false;
}
byteBuffer.flip();
}
return true;
}
@Override
public synchronized int read() throws IOException {
if (!refill()) {
return -1;
}
return byteBuffer.get() & 0xFF;
}
@Override
public synchronized int read(byte[] b, int offset, int len) throws IOException {
if (offset < 0 || len < 0 || offset + len < 0 || offset + len > b.length) {
throw new IndexOutOfBoundsException();
}
if (!refill()) {
return -1;
}
len = Math.min(len, byteBuffer.remaining());
byteBuffer.get(b, offset, len);
return len;
}
@Override
public synchronized int available() throws IOException {
return byteBuffer.remaining();
}
@Override
public synchronized long skip(long n) throws IOException {
if (n <= 0L) {
return 0L;
}
if (byteBuffer.remaining() >= n) {
// The buffered content is enough to skip
byteBuffer.position(byteBuffer.position() + (int) n);
return n;
}
long skippedFromBuffer = byteBuffer.remaining();
long toSkipFromFileChannel = n - skippedFromBuffer;
// Discard everything we have read in the buffer.
byteBuffer.position(0);
byteBuffer.flip();
return skippedFromBuffer + skipFromFileChannel(toSkipFromFileChannel);
}
private long skipFromFileChannel(long n) throws IOException {
long currentFilePosition = fileChannel.position();
long size = fileChannel.size();
if (n > size - currentFilePosition) {
fileChannel.position(size);
return size - currentFilePosition;
} else {
fileChannel.position(currentFilePosition + n);
return n;
}
}
@Override
public synchronized void close() throws IOException {
fileChannel.close();
StorageUtils.dispose(byteBuffer);
}
@Override
protected void finalize() throws IOException {
close();
}
}
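For reference, a minimal usage sketch of the new stream (the spill file path is hypothetical); it is intended as a drop-in replacement for a `BufferedInputStream` wrapped around a `FileInputStream`, as the hunks below show:

```java
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import org.apache.spark.io.NioBufferedFileInputStream;

public class NioStreamUsageSketch {
  public static void main(String[] args) throws IOException {
    File spillFile = new File("/tmp/spill.data");  // hypothetical spill file
    // Same wrapping pattern as before, with the NIO-buffered stream swapped in.
    try (DataInputStream in = new DataInputStream(
             new NioBufferedFileInputStream(spillFile, 64 * 1024))) {
      int recordLength = in.readInt();  // consume data as usual
    }
  }
}
```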


@@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
import org.apache.spark.SparkEnv;
+import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.unsafe.Platform;
@@ -69,8 +70,8 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
}
-    final BufferedInputStream bs =
-        new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
+    final InputStream bs =
+        new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
try {
this.in = serializerManager.wrapStream(blockId, bs);
this.din = new DataInputStream(this.in);


@@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
+import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
@@ -89,7 +90,7 @@ private[spark] class IndexShuffleBlockResolver(
val lengths = new Array[Long](blocks)
// Read the lengths of blocks
val in = try {
-      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+      new DataInputStream(new NioBufferedFileInputStream(index))
} catch {
case e: IOException =>
return null


@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.io;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.RandomUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import static org.junit.Assert.assertEquals;
/**
* Tests functionality of {@link NioBufferedFileInputStream}
*/
public class NioBufferedFileInputStreamSuite {
private byte[] randomBytes;
private File inputFile;
@Before
public void setUp() throws IOException {
// Create a byte array of size 2 MB with random bytes
randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024);
inputFile = File.createTempFile("temp-file", ".tmp");
FileUtils.writeByteArrayToFile(inputFile, randomBytes);
}
@After
public void tearDown() {
inputFile.delete();
}
@Test
public void testReadOneByte() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
@Test
public void testReadMultipleBytes() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
byte[] readBytes = new byte[8 * 1024];
int i = 0;
while (i < randomBytes.length) {
int read = inputStream.read(readBytes, 0, 8 * 1024);
for (int j = 0; j < read; j++) {
assertEquals(randomBytes[i], readBytes[j]);
i++;
}
}
}
@Test
public void testBytesSkipped() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
@Test
public void testBytesSkippedAfterRead() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
@Test
public void testNegativeBytesSkippedAfterRead() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
// Skipping negative bytes should essentially be a no-op
assertEquals(0, inputStream.skip(-1));
assertEquals(0, inputStream.skip(-1024));
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
@Test
public void testSkipFromFileChannel() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10);
// Since the buffer is smaller than the number of skipped bytes, this guarantees
// that we skip from the underlying file channel.
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < 2048; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(256, inputStream.skip(256));
assertEquals(256, inputStream.skip(256));
assertEquals(512, inputStream.skip(512));
for (int i = 3072; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}
@Test
public void testBytesSkippedAfterEOF() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
assertEquals(-1, inputStream.read());
}
}


@@ -22,6 +22,7 @@ import java.io._
import com.google.common.io.Closeables
import org.apache.spark.SparkException
+import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform
@@ -130,7 +131,7 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
if (out != null) {
out.close()
out = null
-      in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString)))
+      in = new DataInputStream(new NioBufferedFileInputStream(file))
}
if (unreadBytes > 0) {