[SPARK-21342] Fix DownloadCallback to work well with RetryingBlockFetcher.

## What changes were proposed in this pull request?

When `RetryingBlockFetcher` retries fetching blocks, two `DownloadCallback`s can end up downloading the same content to the same target file: the temp files were created up front and shared across attempts, so the callback from a failed attempt may still be writing to a file when the retry's callback reopens it. This can cause `ShuffleBlockFetcherIterator` to read a partial result.

This PR proposes to create and delete the temp files inside `OneForOneBlockFetcher` instead: each `DownloadCallback` now obtains a fresh file from a new `TempShuffleFileManager` interface, so no two callbacks can ever share a target file.
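For illustration, here is a minimal standalone sketch of the contract this introduces (the class `InMemoryTempShuffleFileManager` below is hypothetical and not part of this patch): every fetch attempt asks the manager for a fresh temp file, and a file is only registered for deferred cleanup while the consumer is still alive.

```java
import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

// Hypothetical minimal implementation of the TempShuffleFileManager contract.
// Because each DownloadCallback asks for its own fresh file, a retried fetch
// can never write into a file that an earlier attempt is still streaming to.
public class InMemoryTempShuffleFileManager {

  private boolean stopped = false;                  // plays the role of the iterator's zombie flag
  private final Set<File> filesToClean = new HashSet<>();

  public synchronized File createTempShuffleFile() throws IOException {
    return File.createTempFile("shuffle", ".tmp");  // a fresh file per callback
  }

  public synchronized boolean registerTempShuffleFileToClean(File file) {
    if (stopped) {
      return false;                                 // caller must delete the file itself
    }
    filesToClean.add(file);
    return true;
  }

  public synchronized void cleanup() {
    stopped = true;
    for (File f : filesToClean) {
      f.delete();                                   // best-effort deletion of registered files
    }
    filesToClean.clear();
  }
}
```

If `registerTempShuffleFileToClean` returns `false`, the fetcher deletes the file itself; that is exactly what the new `DownloadCallback` in the diff below does on both the success and failure paths.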

Author: jinxing <jinxing6042@126.com>
Author: Shixiong Zhu <zsxwing@gmail.com>

Closes #18565 from jinxing64/SPARK-21342.
jinxing 2017-07-10 21:06:58 +08:00 committed by Wenchen Fan
parent 647963a26a
commit 6a06c4b03c
11 changed files with 108 additions and 45 deletions

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java

@@ -17,7 +17,6 @@
package org.apache.spark.network.shuffle;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
@@ -91,15 +90,15 @@ public class ExternalShuffleClient extends ShuffleClient {
String execId,
String[] blockIds,
BlockFetchingListener listener,
File[] shuffleFiles) {
TempShuffleFileManager tempShuffleFileManager) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
(blockIds1, listener1) -> {
TransportClient client = clientFactory.createClient(host, port);
new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
shuffleFiles).start();
new OneForOneBlockFetcher(client, appId, execId,
blockIds1, listener1, conf, tempShuffleFileManager).start();
};
int maxRetries = conf.maxIORetries();

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java

@@ -57,11 +57,21 @@ public class OneForOneBlockFetcher {
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
private TransportConf transportConf = null;
private File[] shuffleFiles = null;
private final TransportConf transportConf;
private final TempShuffleFileManager tempShuffleFileManager;
private StreamHandle streamHandle = null;
public OneForOneBlockFetcher(
TransportClient client,
String appId,
String execId,
String[] blockIds,
BlockFetchingListener listener,
TransportConf transportConf) {
this(client, appId, execId, blockIds, listener, transportConf, null);
}
public OneForOneBlockFetcher(
TransportClient client,
String appId,
@@ -69,18 +79,14 @@ public class OneForOneBlockFetcher {
String[] blockIds,
BlockFetchingListener listener,
TransportConf transportConf,
File[] shuffleFiles) {
TempShuffleFileManager tempShuffleFileManager) {
this.client = client;
this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
if (shuffleFiles != null) {
this.shuffleFiles = shuffleFiles;
assert this.shuffleFiles.length == blockIds.length:
"Number of shuffle files should equal to blocks";
}
this.tempShuffleFileManager = tempShuffleFileManager;
}
/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -119,9 +125,9 @@ public class OneForOneBlockFetcher {
// Immediately request all chunks -- we expect that the total size of the request is
// reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
for (int i = 0; i < streamHandle.numChunks; i++) {
if (shuffleFiles != null) {
if (tempShuffleFileManager != null) {
client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
new DownloadCallback(shuffleFiles[i], i));
new DownloadCallback(i));
} else {
client.fetchChunk(streamHandle.streamId, i, chunkCallback);
}
@@ -157,8 +163,8 @@ public class OneForOneBlockFetcher {
private File targetFile = null;
private int chunkIndex;
DownloadCallback(File targetFile, int chunkIndex) throws IOException {
this.targetFile = targetFile;
DownloadCallback(int chunkIndex) throws IOException {
this.targetFile = tempShuffleFileManager.createTempShuffleFile();
this.channel = Channels.newChannel(new FileOutputStream(targetFile));
this.chunkIndex = chunkIndex;
}
@@ -174,6 +180,9 @@ public class OneForOneBlockFetcher {
ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
targetFile.length());
listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) {
targetFile.delete();
}
}
@Override
@@ -182,6 +191,7 @@ public class OneForOneBlockFetcher {
// On receipt of a failure, fail every block from chunkIndex onwards.
String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
failRemainingBlocks(remainingBlockIds, cause);
targetFile.delete();
}
}
}

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java

@@ -18,7 +18,6 @@
package org.apache.spark.network.shuffle;
import java.io.Closeable;
import java.io.File;
/** Provides an interface for reading shuffle files, either from an Executor or external service. */
public abstract class ShuffleClient implements Closeable {
@@ -35,6 +34,16 @@ public abstract class ShuffleClient implements Closeable {
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*
* @param host the host of the remote node.
* @param port the port of the remote node.
* @param execId the executor id.
* @param blockIds block ids to fetch.
* @param listener the listener to receive block fetching status.
* @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files.
* If it's not <code>null</code>, the remote blocks will be streamed
* into temp shuffle files to reduce the memory usage, otherwise,
* they will be kept in memory.
*/
public abstract void fetchBlocks(
String host,
@@ -42,5 +51,5 @@ public abstract class ShuffleClient implements Closeable {
String execId,
String[] blockIds,
BlockFetchingListener listener,
File[] shuffleFiles);
TempShuffleFileManager tempShuffleFileManager);
}

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java (new file)

@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.shuffle;
import java.io.File;
/**
* A manager to create temp shuffle block files to reduce the memory usage and also clean temp
* files when they won't be used any more.
*/
public interface TempShuffleFileManager {
/** Create a temp shuffle block file. */
File createTempShuffleFile();
/**
* Register a temp shuffle file to clean up when it won't be used any more. Return whether the
* file is registered successfully. If `false`, the caller should clean up the file by itself.
*/
boolean registerTempShuffleFileToClean(File file);
}
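To make the ownership rules concrete, here is a hedged caller-side sketch (simplified from `DownloadCallback` earlier in this diff; the helper class and method names below are illustrative only, not part of this patch):

```java
import java.io.File;

// Illustrative summary of who deletes a temp file under the
// TempShuffleFileManager protocol.
class TempShuffleFileProtocol {

  // Success path: hand the file to the manager for deferred cleanup. If the
  // consumer is already shut down (registration returns false), the producer
  // deletes the file immediately to avoid leaking it.
  static void onComplete(TempShuffleFileManager manager, File targetFile) {
    if (!manager.registerTempShuffleFileToClean(targetFile)) {
      targetFile.delete();
    }
  }

  // Failure path: the file was never registered, so the producer deletes it.
  static void onFailure(File targetFile) {
    targetFile.delete();
  }
}
```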

common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java

@@ -204,7 +204,7 @@ public class SaslIntegrationSuite {
String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
fetcher.start();
blockFetchLatch.await();
checkSecurityException(exception.get());

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java

@@ -131,7 +131,7 @@ public class OneForOneBlockFetcherSuite {
BlockFetchingListener listener = mock(BlockFetchingListener.class);
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
OneForOneBlockFetcher fetcher =
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);
new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf);
// Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
doAnswer(invocationOnMock -> {

core/src/main/scala/org/apache/spark/network/BlockTransferService.scala

@@ -17,7 +17,7 @@
package org.apache.spark.network
import java.io.{Closeable, File}
import java.io.Closeable
import java.nio.ByteBuffer
import scala.concurrent.{Future, Promise}
@@ -26,7 +26,7 @@ import scala.reflect.ClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils
@@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit
tempShuffleFileManager: TempShuffleFileManager): Unit
/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -101,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
ret.flip()
result.success(new NioManagedBuffer(ret))
}
}, shuffleFiles = null)
}, tempShuffleFileManager = null)
ThreadUtils.awaitResult(result.future, Duration.Inf)
}

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

@@ -17,7 +17,6 @@
package org.apache.spark.network.netty
import java.io.File
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
@@ -30,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager}
import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
@@ -90,14 +89,14 @@ private[spark] class NettyBlockTransferService(
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit = {
tempShuffleFileManager: TempShuffleFileManager): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
transportConf, shuffleFiles).start()
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
transportConf, tempShuffleFileManager).start()
}
}

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

@@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBufferOutputStream
@@ -66,7 +66,7 @@ final class ShuffleBlockFetcherIterator(
maxReqsInFlight: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with Logging {
extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {
import ShuffleBlockFetcherIterator._
@@ -135,7 +135,8 @@ final class ShuffleBlockFetcherIterator(
* A set to store the files used for shuffling remote huge blocks. Files in this set will be
* deleted when cleanup. This is a layer of defensiveness against disk file leaks.
*/
val shuffleFilesSet = mutable.HashSet[File]()
@GuardedBy("this")
private[this] val shuffleFilesSet = mutable.HashSet[File]()
initialize()
@@ -149,6 +150,19 @@ final class ShuffleBlockFetcherIterator(
currentResult = null
}
override def createTempShuffleFile(): File = {
blockManager.diskBlockManager.createTempLocalBlock()._2
}
override def registerTempShuffleFileToClean(file: File): Boolean = synchronized {
if (isZombie) {
false
} else {
shuffleFilesSet += file
true
}
}
/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
@@ -176,7 +190,7 @@ final class ShuffleBlockFetcherIterator(
}
shuffleFilesSet.foreach { file =>
if (!file.delete()) {
logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath());
logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
}
}
}
@@ -221,12 +235,8 @@ final class ShuffleBlockFetcherIterator(
// already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
// the data and write it to file directly.
if (req.size > maxReqSizeShuffleToMem) {
val shuffleFiles = blockIds.map { _ =>
blockManager.diskBlockManager.createTempLocalBlock()._2
}.toArray
shuffleFilesSet ++= shuffleFiles
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, shuffleFiles)
blockFetchingListener, this)
} else {
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, null)

core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

@@ -45,7 +45,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
@@ -1382,7 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
shuffleFiles: Array[File]): Unit = {
tempShuffleFileManager: TempShuffleFileManager): Unit = {
listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
}

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

@@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester
import org.apache.spark.{SparkFunSuite, TaskContext}
import org.apache.spark.network._
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager}
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
@@ -432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val remoteBlocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
val transfer = mock(classOf[BlockTransferService])
var shuffleFiles: Array[File] = null
var tempShuffleFileManager: TempShuffleFileManager = null
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager]
Future {
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
@@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
fetchShuffleBlock(blocksByAddress1)
// `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
// shuffle block to disk.
assert(shuffleFiles === null)
assert(tempShuffleFileManager == null)
val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
fetchShuffleBlock(blocksByAddress2)
// `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
// shuffle block to disk.
assert(shuffleFiles != null)
assert(tempShuffleFileManager != null)
}
}