[SPARK-35543][CORE][FOLLOWUP] Fix memory leak in BlockManagerMasterEndpoint removeRdd

### What changes were proposed in this pull request?

Wrap the `JHashMap[BlockId, BlockStatus]` (used in `blockStatusByShuffleService`) in a new class, `BlockStatusPerBlockId`, which drops its reference to the backing map once all the persisted blocks are removed.
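
For illustration, a minimal generic sketch of the idea (the name `LazyReleasingMap` is hypothetical; the real `BlockStatusPerBlockId` is in the diff below): the backing map is allocated lazily on the first `put` and released once the last entry is removed, while callers keep a stable reference to the wrapper itself.

```scala
import java.util.{HashMap => JHashMap}

// Hypothetical sketch: a wrapper that callers can hold on to permanently,
// while the memory-heavy backing map comes and goes with the entries.
class LazyReleasingMap[K, V] {
  private var backing: JHashMap[K, V] = _  // null while empty

  def get(key: K): Option[V] =
    if (backing == null) None else Option(backing.get(key))

  def put(key: K, value: V): Unit = {
    if (backing == null) backing = new JHashMap[K, V]  // re-allocate on demand
    backing.put(key, value)
  }

  def remove(key: K): Unit = if (backing != null) {
    backing.remove(key)
    if (backing.isEmpty) backing = null  // release the map, keep the wrapper
  }
}
```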

### Why are the changes needed?

With https://github.com/apache/spark/pull/32790 a bug was introduced: when all the persisted blocks are removed, the HashMap that is already shared with the block manager infos gets removed as well. But when a new block is persisted, that same map must be used again for storing the data, and it has to be the very instance shared with the block manager infos created for registered block managers running on the same host as the external shuffle service.
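
A self-contained sketch of the aliasing problem (simplified types and hypothetical names; `String` stands in for `BlockId`/`BlockStatus`):

```scala
import java.util.{HashMap => JHashMap}
import scala.collection.mutable

object SharedMapBugRepro extends App {
  val blockStatusByShuffleService = new mutable.HashMap[String, JHashMap[String, String]]

  // Registration: master and BlockManagerInfo end up sharing one map instance.
  val shared = blockStatusByShuffleService.getOrElseUpdate("host1:7337", new JHashMap[String, String])
  val heldByBlockManagerInfo = shared

  shared.put("rdd_0_0", "DISK_ONLY")
  shared.remove("rdd_0_0")
  // The buggy cleanup from #32790: drop the entry once the map is empty.
  if (shared.isEmpty) blockStatusByShuffleService.remove("host1:7337")

  // A later lookup re-creates a *fresh* map under the same key...
  val fresh = blockStatusByShuffleService.getOrElseUpdate("host1:7337", new JHashMap[String, String])
  fresh.put("rdd_0_1", "DISK_ONLY")

  // ...which the still-alive BlockManagerInfo never sees: the two views diverged.
  assert(heldByBlockManagerInfo.get("rdd_0_1") == null)
}
```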

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Extended `BlockManagerInfoSuite` with a test that removes all the persisted blocks and then adds another one.
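
Condensed, the regression scenario looks like this (not standalone; `bmInfo` and `getEssBlockStatus` come from the suite in the diff below):

```scala
val rddId1: BlockId = RDDBlockId(0, 0)
bmInfo.updateBlockInfo(rddId1, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
bmInfo.removeBlock(rddId1)  // last persisted block gone: the backing map is released

val rddId2: BlockId = RDDBlockId(0, 1)
bmInfo.updateBlockInfo(rddId2, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
// The wrapper must transparently re-allocate its backing map:
assert(getEssBlockStatus(bmInfo, rddId2) === Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
```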

Closes #33020 from attilapiros/SPARK-35543-2.

Authored-by: attilapiros <piros.attila.zsolt@gmail.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
attilapiros, 2021-06-24 00:01:40 -05:00, committed by Mridul Muralidharan
commit 0bdece015e (parent 1cdc56c70d)
2 changed files with 67 additions and 34 deletions

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala

@@ -63,7 +63,7 @@ class BlockManagerMasterEndpoint(
 
   // Mapping from external shuffle service block manager id to the block statuses.
   private val blockStatusByShuffleService =
-    new mutable.HashMap[BlockManagerId, JHashMap[BlockId, BlockStatus]]
+    new mutable.HashMap[BlockManagerId, BlockStatusPerBlockId]
 
   // Mapping from executor ID to block manager ID.
   private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
@@ -278,11 +278,6 @@ class BlockManagerMasterEndpoint(
             blockIdsToDel += blockId
             blockStatusByShuffleService.get(bmIdForShuffleService).foreach { blockStatusForId =>
               blockStatusForId.remove(blockId)
-              // when all blocks are removed from the block statuses then for this BM Id the whole
-              // blockStatusByShuffleService entry can be removed to avoid leaking memory
-              if (blockStatusForId.isEmpty) {
-                blockStatusByShuffleService.remove(bmIdForShuffleService)
-              }
             }
           }
         }
@@ -569,8 +564,12 @@ class BlockManagerMasterEndpoint(
 
     val externalShuffleServiceBlockStatus =
       if (externalShuffleServiceRddFetchEnabled) {
+        // The blockStatusByShuffleService entries are never removed as they belong to the
+        // external shuffle service instances running on the cluster nodes. To decrease its
+        // memory footprint when all the disk persisted blocks are removed for a shuffle service
+        // BlockStatusPerBlockId releases the backing HashMap.
         val externalShuffleServiceBlocks = blockStatusByShuffleService
-          .getOrElseUpdate(externalShuffleServiceIdOnHost(id), new JHashMap[BlockId, BlockStatus])
+          .getOrElseUpdate(externalShuffleServiceIdOnHost(id), new BlockStatusPerBlockId)
         Some(externalShuffleServiceBlocks)
       } else {
         None
@@ -671,7 +670,7 @@ class BlockManagerMasterEndpoint(
     val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty)
     val status = locations.headOption.flatMap { bmId =>
       if (externalShuffleServiceRddFetchEnabled && bmId.port == externalShuffleServicePort) {
-        blockStatusByShuffleService.get(bmId).flatMap(m => Option(m.get(blockId)))
+        blockStatusByShuffleService.get(bmId).flatMap(m => m.get(blockId))
       } else {
         aliveBlockManagerInfo(bmId).flatMap(_.getStatus(blockId))
       }
@@ -794,19 +793,44 @@ object BlockStatus {
   def empty: BlockStatus = BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L)
 }
 
+/**
+ * Stores block statuses for block IDs but removes the reference to the Map which used for storing
+ * the data when all the blocks are removed to avoid keeping the memory when not needed.
+ */
+private[spark] class BlockStatusPerBlockId {
+
+  private var blocks: JHashMap[BlockId, BlockStatus] = _
+
+  def get(blockId: BlockId): Option[BlockStatus] =
+    if (blocks == null) None else Option(blocks.get(blockId))
+
+  def put(blockId: BlockId, blockStatus: BlockStatus): Unit = {
+    if (blocks == null) {
+      blocks = new JHashMap[BlockId, BlockStatus]
+    }
+    blocks.put(blockId, blockStatus)
+  }
+
+  def remove(blockId: BlockId): Unit = {
+    blocks.remove(blockId)
+    if (blocks.isEmpty) {
+      blocks = null
+    }
+  }
+}
+
 private[spark] class BlockManagerInfo(
     val blockManagerId: BlockManagerId,
     timeMs: Long,
     val maxOnHeapMem: Long,
     val maxOffHeapMem: Long,
     val storageEndpoint: RpcEndpointRef,
-    val externalShuffleServiceBlockStatus: Option[JHashMap[BlockId, BlockStatus]])
+    val externalShuffleServiceBlockStatus: Option[BlockStatusPerBlockId])
   extends Logging {
 
   val maxMem = maxOnHeapMem + maxOffHeapMem
 
   val externalShuffleServiceEnabled = externalShuffleServiceBlockStatus.isDefined
 
   private var _lastSeenMs: Long = timeMs
   private var _remainingMem: Long = maxMem
   private var _executorRemovalTs: Option[Long] = None

core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala

@@ -17,15 +17,13 @@
 package org.apache.spark.storage
 
-import java.util.{HashMap => JHashMap}
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.SparkFunSuite
 
 class BlockManagerInfoSuite extends SparkFunSuite {
 
-  def testWithShuffleServiceOnOff(testName: String)
+  private def testWithShuffleServiceOnOff(testName: String)
       (f: (Boolean, BlockManagerInfo) => Unit): Unit = {
     Seq(true, false).foreach { svcEnabled =>
       val bmInfo = new BlockManagerInfo(
@@ -34,13 +32,19 @@ class BlockManagerInfoSuite extends SparkFunSuite {
         maxOnHeapMem = 10000,
         maxOffHeapMem = 20000,
         storageEndpoint = null,
-        if (svcEnabled) Some(new JHashMap[BlockId, BlockStatus]) else None)
+        if (svcEnabled) Some(new BlockStatusPerBlockId) else None)
       test(s"$testName externalShuffleServiceEnabled=$svcEnabled") {
         f(svcEnabled, bmInfo)
       }
     }
   }
 
+  private def getEssBlockStatus(bmInfo: BlockManagerInfo, blockId: BlockId): Option[BlockStatus] = {
+    assert(bmInfo.externalShuffleServiceBlockStatus.isDefined)
+    val blockStatusPerBlockId = bmInfo.externalShuffleServiceBlockStatus.get
+    blockStatusPerBlockId.get(blockId)
+  }
+
   testWithShuffleServiceOnOff("broadcast block") { (_, bmInfo) =>
     val broadcastId: BlockId = BroadcastBlockId(0, "field1")
     bmInfo.updateBlockInfo(
@@ -57,7 +61,7 @@ class BlockManagerInfoSuite extends SparkFunSuite {
       Map(rddId -> BlockStatus(StorageLevel.MEMORY_ONLY, 200, 0)))
     assert(bmInfo.remainingMem === 29800)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty)
+      assert(getEssBlockStatus(bmInfo, rddId).isEmpty)
     }
   }
@@ -70,8 +74,8 @@ class BlockManagerInfoSuite extends SparkFunSuite {
       Map(rddId -> BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 400)))
     assert(bmInfo.remainingMem === 29800)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala ===
-        Map(rddId -> BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 400)))
+      assert(getEssBlockStatus(bmInfo, rddId) ===
+        Some(BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 400)))
     }
   }
@@ -83,8 +87,7 @@ class BlockManagerInfoSuite extends SparkFunSuite {
     val exclusiveCachedBlocksForOneMemoryOnly = if (svcEnabled) Set() else Set(rddId)
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala ===
-        Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+      assert(getEssBlockStatus(bmInfo, rddId) === Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     }
   }
@@ -96,15 +99,14 @@ class BlockManagerInfoSuite extends SparkFunSuite {
     assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.MEMORY_ONLY, 200, 0)))
     assert(bmInfo.remainingMem === 29800)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty)
+      assert(getEssBlockStatus(bmInfo, rddId).isEmpty)
     }
 
     bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
     assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala ===
-        Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+      assert(getEssBlockStatus(bmInfo, rddId) === Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     }
   }
@@ -114,33 +116,40 @@ class BlockManagerInfoSuite extends SparkFunSuite {
     assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala ===
-        Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+      assert(getEssBlockStatus(bmInfo, rddId) === Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     }
 
     bmInfo.updateBlockInfo(rddId, StorageLevel.NONE, memSize = 0, diskSize = 200)
     assert(bmInfo.blocks.isEmpty)
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty)
+      assert(getEssBlockStatus(bmInfo, rddId).isEmpty)
     }
   }
 
-  testWithShuffleServiceOnOff("remove block") { (svcEnabled, bmInfo) =>
-    val rddId: BlockId = RDDBlockId(0, 0)
-    bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
-    assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+  testWithShuffleServiceOnOff("remove block and add another one") { (svcEnabled, bmInfo) =>
+    val rddId1: BlockId = RDDBlockId(0, 0)
+    bmInfo.updateBlockInfo(rddId1, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
+    assert(bmInfo.blocks.asScala === Map(rddId1 -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala ===
-        Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+      assert(getEssBlockStatus(bmInfo, rddId1) ===
+        Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
     }
 
-    bmInfo.removeBlock(rddId)
+    bmInfo.removeBlock(rddId1)
     assert(bmInfo.blocks.asScala.isEmpty)
     assert(bmInfo.remainingMem === 30000)
     if (svcEnabled) {
-      assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty)
+      assert(getEssBlockStatus(bmInfo, rddId1).isEmpty)
     }
+
+    val rddId2: BlockId = RDDBlockId(0, 1)
+    bmInfo.updateBlockInfo(rddId2, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200)
+    assert(bmInfo.blocks.asScala === Map(rddId2 -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+    assert(bmInfo.remainingMem === 30000)
+    if (svcEnabled) {
+      assert(getEssBlockStatus(bmInfo, rddId2) ===
+        Some(BlockStatus(StorageLevel.DISK_ONLY, 0, 200)))
+    }
   }
 }