[SPARK-29310][CORE][TESTS] TestMemoryManager should implement getExecutionMemoryUsageForTask()

### What changes were proposed in this pull request?

This PR updates `TestMemoryManager`, a class used only in unit tests, to override the `getExecutionMemoryUsageForTask()` and `releaseAllExecutionMemoryForTask()` methods. I also made its state-accessing methods `synchronized` (making the class thread-safe) and added assertions to guard against freeing more memory than has been allocated.
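
For illustration, here is a minimal sketch of what the new overrides let a test observe (condensed from the new `TestMemoryManagerSuite` added below; the suite name here is hypothetical, and the code assumes the `org.apache.spark.memory` package because these methods are `private[memory]`):

```scala
package org.apache.spark.memory

import org.apache.spark.{SparkConf, SparkFunSuite}

class PerTaskTrackingExampleSuite extends SparkFunSuite {
  test("per-task execution memory is tracked") {
    val mm = new TestMemoryManager(new SparkConf())
    // Acquire 10 bytes on behalf of task attempt 0; usage is now attributed to that task.
    mm.acquireExecutionMemory(10, 0, MemoryMode.ON_HEAP)
    assert(mm.getExecutionMemoryUsageForTask(0) == 10L)
    // Releasing everything for the task both clears and reports the outstanding amount.
    assert(mm.releaseAllExecutionMemoryForTask(0) == 10L)
    assert(mm.getExecutionMemoryUsageForTask(0) == 0L)
  }
}
```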

### Why are the changes needed?

Spark uses a `TestMemoryManager` class to mock out memory manager functionality in tests, allowing test authors to exercise control over certain behaviors (e.g. to simulate OOMs).
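
For example, a test can force the next allocation request to be denied, as in the quick sketch below (this exercises the pre-existing `markExecutionAsOutOfMemoryOnce()` behavior rather than anything added by this PR; the suite name is hypothetical, and the code assumes the `org.apache.spark.memory` package so that it can call the `private[memory]` `acquireExecutionMemory()`):

```scala
package org.apache.spark.memory

import org.apache.spark.{SparkConf, SparkFunSuite}

class SimulatedOomExampleSuite extends SparkFunSuite {
  test("simulating an allocation failure") {
    val mm = new TestMemoryManager(new SparkConf())
    mm.markExecutionAsOutOfMemoryOnce()
    // The next request is denied (0 bytes granted), exercising the caller's OOM handling.
    assert(mm.acquireExecutionMemory(1024, 0, MemoryMode.ON_HEAP) == 0L)
    // Subsequent requests succeed again.
    assert(mm.acquireExecutionMemory(1024, 0, MemoryMode.ON_HEAP) == 1024L)
  }
}
```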

Several of Spark's test suites have memory-leak detection to ensure that all allocated memory is cleaned up at the end of each test case; this helps to guard against bugs that could cause production memory leaks. For example, see `testWithMemoryLeakDetection` in `UnsafeFixedWidthAggregationMapSuite`.
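
Roughly, such a check asks the memory manager how much execution memory the task attempt still holds after the test body finishes, and fails the test if the answer is non-zero. The sketch below only illustrates that pattern; it is not the actual `testWithMemoryLeakDetection` implementation, and `LeakCheckExampleSuite`, `testWithLeakCheck`, and `taskAttemptId` are illustrative names (the code assumes the `org.apache.spark.memory` package, where the `private[memory]` accessors are visible):

```scala
package org.apache.spark.memory

import org.apache.spark.{SparkConf, SparkFunSuite}

class LeakCheckExampleSuite extends SparkFunSuite {
  private val memoryManager = new TestMemoryManager(new SparkConf())
  private val taskAttemptId = 0L

  /** Runs `body` as a test and fails it if any execution memory is still allocated. */
  private def testWithLeakCheck(name: String)(body: => Unit): Unit = {
    test(name) {
      body
      val leaked = memoryManager.getExecutionMemoryUsageForTask(taskAttemptId)
      assert(leaked == 0L, s"Test leaked $leaked bytes of execution memory")
    }
  }

  testWithLeakCheck("balanced acquire/release passes the check") {
    memoryManager.acquireExecutionMemory(256, taskAttemptId, MemoryMode.ON_HEAP)
    memoryManager.releaseExecutionMemory(256, taskAttemptId, MemoryMode.ON_HEAP)
  }
}
```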

Unfortunately, this leak-detection logic is broken for tests that use `TestMemoryManager`, because that class does not override the `getExecutionMemoryUsageForTask()` method that the leak-detection checks rely on.

This PR fixes that problem, thereby strengthening our existing tests.

I spotted this problem while reviewing #25953: I tried introducing a change to remove a `freePage()` call (purposely inducing a memory leak) but no tests failed.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added a new `TestMemoryManagerSuite`, with tests covering `TestMemoryManager` itself.

Ran a subset of existing tests on my laptop and uncovered a bug in one test's `free()` calls, plus missing cleanup calls in another test suite; both of these issues are fixed in this PR.

Closes #25985 from JoshRosen/SPARK-29310-testmemorymanager-getExecutionMemoryUsageForTask.

Lead-authored-by: Josh Rosen <rosenville@gmail.com>
Co-authored-by: joshrosen-stripe <48632449+joshrosen-stripe@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
commit c6938eab57 (parent 2ec3265ae7)
Authored by Josh Rosen on 2019-10-02 11:07:19 +08:00; committed by Wenchen Fan
5 changed files with 139 additions and 31 deletions

MemoryConsumer.java:

@@ -54,7 +54,7 @@ public abstract class MemoryConsumer {
   /**
    * Returns the size of used memory in bytes.
    */
-  protected long getUsed() {
+  public long getUsed() {
     return used;
   }

AbstractBytesToBytesMapSuite.java:

@@ -692,13 +692,11 @@ public abstract class AbstractBytesToBytesMapSuite {
     Thread thread = new Thread(() -> {
       int i = 0;
-      long used = 0;
       while (i < 10) {
         c1.use(10000000);
-        used += 10000000;
         i++;
       }
-      c1.free(used);
+      c1.free(c1.getUsed());
     });
     try {

UnsafeExternalSorterSuite.java:

@@ -235,6 +235,9 @@ public class UnsafeExternalSorterSuite {
     sorter.insertRecord(null, 0, 0, 0, false);
     UnsafeSorterIterator iter = sorter.getSortedIterator();
     assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
   }
 
   @Test
@@ -510,6 +513,8 @@ public class UnsafeExternalSorterSuite {
     verifyIntIterator(sorter.getIterator(79), 79, 300);
     verifyIntIterator(sorter.getIterator(139), 139, 300);
     verifyIntIterator(sorter.getIterator(279), 279, 300);
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
   }
 
   @Test

TestMemoryManager.scala:

@@ -17,60 +17,110 @@
 
 package org.apache.spark.memory
 
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.mutable
+
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.BlockId
 
 class TestMemoryManager(conf: SparkConf)
   extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) {
 
+  @GuardedBy("this")
+  private var consequentOOM = 0
+  @GuardedBy("this")
+  private var available = Long.MaxValue
+  @GuardedBy("this")
+  private val memoryForTask = mutable.HashMap[Long, Long]().withDefaultValue(0L)
+
   override private[memory] def acquireExecutionMemory(
       numBytes: Long,
       taskAttemptId: Long,
-      memoryMode: MemoryMode): Long = {
-    if (consequentOOM > 0) {
-      consequentOOM -= 1
-      0
-    } else if (available >= numBytes) {
-      available -= numBytes
-      numBytes
-    } else {
-      val grant = available
-      available = 0
-      grant
+      memoryMode: MemoryMode): Long = synchronized {
+    require(numBytes >= 0)
+    val acquired = {
+      if (consequentOOM > 0) {
+        consequentOOM -= 1
+        0
+      } else if (available >= numBytes) {
+        available -= numBytes
+        numBytes
+      } else {
+        val grant = available
+        available = 0
+        grant
+      }
     }
+    memoryForTask(taskAttemptId) = memoryForTask.getOrElse(taskAttemptId, 0L) + acquired
+    acquired
   }
 
-  override def acquireStorageMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      memoryMode: MemoryMode): Boolean = true
-
-  override def acquireUnrollMemory(
-      blockId: BlockId,
-      numBytes: Long,
-      memoryMode: MemoryMode): Boolean = true
-
-  override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {}
-
   override private[memory] def releaseExecutionMemory(
       numBytes: Long,
       taskAttemptId: Long,
-      memoryMode: MemoryMode): Unit = {
+      memoryMode: MemoryMode): Unit = synchronized {
+    require(numBytes >= 0)
     available += numBytes
+    val existingMemoryUsage = memoryForTask.getOrElse(taskAttemptId, 0L)
+    val newMemoryUsage = existingMemoryUsage - numBytes
+    require(
+      newMemoryUsage >= 0,
+      s"Attempting to free $numBytes of memory for task attempt $taskAttemptId, but it only " +
+        s"allocated $existingMemoryUsage bytes of memory")
+    memoryForTask(taskAttemptId) = newMemoryUsage
   }
 
+  override private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = {
+    memoryForTask.remove(taskAttemptId).getOrElse(0L)
+  }
+
+  override private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = {
+    memoryForTask.getOrElse(taskAttemptId, 0L)
+  }
+
+  override def acquireStorageMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      memoryMode: MemoryMode): Boolean = {
+    require(numBytes >= 0)
+    true
+  }
+
+  override def acquireUnrollMemory(
+      blockId: BlockId,
+      numBytes: Long,
+      memoryMode: MemoryMode): Boolean = {
+    require(numBytes >= 0)
+    true
+  }
+
+  override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {
+    require(numBytes >= 0)
+  }
+
   override def maxOnHeapStorageMemory: Long = Long.MaxValue
 
   override def maxOffHeapStorageMemory: Long = 0L
 
-  private var consequentOOM = 0
-  private var available = Long.MaxValue
-
+  /**
+   * Causes the next call to [[acquireExecutionMemory()]] to fail to allocate
+   * memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
+   */
   def markExecutionAsOutOfMemoryOnce(): Unit = {
     markconsequentOOM(1)
   }
 
-  def markconsequentOOM(n : Int) : Unit = {
+  /**
+   * Causes the next `n` calls to [[acquireExecutionMemory()]] to fail to allocate
+   * memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
+   */
+  def markconsequentOOM(n: Int): Unit = synchronized {
     consequentOOM += n
   }
 
-  def limit(avail: Long): Unit = {
+  def limit(avail: Long): Unit = synchronized {
+    require(avail >= 0)
     available = avail
   }
 
 }

TestMemoryManagerSuite.scala (new file):

@@ -0,0 +1,55 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.memory

import org.apache.spark.{SparkConf, SparkFunSuite}

/**
 * Tests of [[TestMemoryManager]] itself.
 */
class TestMemoryManagerSuite extends SparkFunSuite {

  test("tracks allocated execution memory by task") {
    val testMemoryManager = new TestMemoryManager(new SparkConf())

    assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
    assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)

    testMemoryManager.acquireExecutionMemory(10, 0, MemoryMode.ON_HEAP)
    testMemoryManager.acquireExecutionMemory(5, 1, MemoryMode.ON_HEAP)
    testMemoryManager.acquireExecutionMemory(5, 0, MemoryMode.ON_HEAP)

    assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 15)
    assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 5)

    testMemoryManager.releaseExecutionMemory(10, 0, MemoryMode.ON_HEAP)
    assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 5)

    testMemoryManager.releaseAllExecutionMemoryForTask(0)
    testMemoryManager.releaseAllExecutionMemoryForTask(1)
    assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
    assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
  }

  test("markconsequentOOM") {
    val testMemoryManager = new TestMemoryManager(new SparkConf())
    assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
    testMemoryManager.markconsequentOOM(2)
    assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
    assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
    assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
  }
}