diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 873330e136..bcf65e9d7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -133,6 +133,9 @@ private[spark] class BlockManager( private val compressRdds = conf.getBoolean("spark.rdd.compress", false) // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) + // Max number of failures before this block manager refreshes the block locations from the driver + private val maxFailuresBeforeLocationRefresh = + conf.getInt("spark.block.failures.beforeLocationRefresh", 5) private val slaveEndpoint = rpcEnv.setupEndpoint( "BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next, @@ -568,26 +571,46 @@ private[spark] class BlockManager( def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { logDebug(s"Getting remote block $blockId") require(blockId != null, "BlockId is null") + var runningFailureCount = 0 + var totalFailureCount = 0 val locations = getLocations(blockId) - var numFetchFailures = 0 - for (loc <- locations) { + val maxFetchFailures = locations.size + var locationIterator = locations.iterator + while (locationIterator.hasNext) { + val loc = locationIterator.next() logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() } catch { case NonFatal(e) => - numFetchFailures += 1 - if (numFetchFailures == locations.size) { - // An exception is thrown while fetching this block from all locations - throw new BlockFetchException(s"Failed to fetch block from" + - s" ${locations.size} locations. Most recent failure cause:", e) - } else { - // This location failed, so we retry fetch from a different one by returning null here - logWarning(s"Failed to fetch remote block $blockId " + - s"from $loc (failed attempt $numFetchFailures)", e) - null + runningFailureCount += 1 + totalFailureCount += 1 + + if (totalFailureCount >= maxFetchFailures) { + // Give up trying anymore locations. Either we've tried all of the original locations, + // or we've refreshed the list of locations from the master, and have still + // hit failures after trying locations from the refreshed list. + throw new BlockFetchException(s"Failed to fetch block after" + + s" ${totalFailureCount} fetch failures. Most recent failure cause:", e) } + + logWarning(s"Failed to fetch remote block $blockId " + + s"from $loc (failed attempt $runningFailureCount)", e) + + // If there is a large number of executors then locations list can contain a + // large number of stale entries causing a large number of retries that may + // take a significant amount of time. To get rid of these stale entries + // we refresh the block locations after a certain number of fetch failures + if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { + locationIterator = getLocations(blockId).iterator + logDebug(s"Refreshed locations from the driver " + + s"after ${runningFailureCount} fetch failures.") + runningFailureCount = 0 + } + + // This location failed, so we retry fetch from a different one by returning null here + null } if (data != null) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 42595c8cf2..dc4be14677 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,11 +21,12 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ +import scala.concurrent.Future import scala.language.implicitConversions import scala.language.postfixOps import org.mockito.{Matchers => mc} -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,7 +34,10 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} @@ -66,9 +70,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, - master: BlockManagerMaster = this.master): BlockManager = { + master: BlockManagerMaster = this.master, + transferService: Option[BlockTransferService] = Option.empty): BlockManager = { val serializer = new KryoSerializer(conf) - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val transfer = transferService + .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1)) val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -1287,6 +1293,78 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingle("a1").isDefined, "a1 was not in store") assert(store.getSingle("a3").isDefined, "a3 was not in store") } + + test("SPARK-13328: refresh block locations (fetch should fail after hitting a threshold)") { + val mockBlockTransferService = + new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5)) + store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) + store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true) + intercept[BlockFetchException] { + store.getRemoteBytes("item") + } + } + + test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") { + val maxFailuresBeforeLocationRefresh = + conf.getInt("spark.block.failures.beforeLocationRefresh", 5) + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = + new MockBlockTransferService(maxFailuresBeforeLocationRefresh) + // make sure we have more than maxFailuresBeforeLocationRefresh locations + // so that we have a chance to do location refresh + val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) + .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + assert(block.isDefined) + verify(mockBlockManagerMaster, times(2)).getLocations("item") + } + + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { + var numCalls = 0 + + override def init(blockDataManager: BlockDataManager): Unit = {} + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener): Unit = { + listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) + } + + override def close(): Unit = {} + + override def hostName: String = { "MockBlockTransferServiceHost" } + + override def port: Int = { 63332 } + + override def uploadBlock( + hostname: String, + port: Int, execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + import scala.concurrent.ExecutionContext.Implicits.global + Future {} + } + + override def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String): ManagedBuffer = { + numCalls += 1 + if (numCalls <= maxFailures) { + throw new RuntimeException("Failing block fetch in the mock block transfer service") + } + super.fetchBlockSync(host, port, execId, blockId) + } + } } private object BlockManagerSuite {