[SPARK-29310][CORE][TESTS] TestMemoryManager should implement getExecutionMemoryUsageForTask()
### What changes were proposed in this pull request? This PR updates `TestMemoryManager`, a class used only in unit tests, to override the `getExecutionMemoryUsageForTask()` and `releaseAllExecutionMemoryForTask()` methods. I also `synchronized` its state-accessing methods (to make the class thread-safe) and added some additional assertions to guard against freeing memory memory than has been allocated. ### Why are the changes needed? Spark uses a `TestMemoryManager` class to mock out memory manager functionality in tests, allowing test authors to exercise control over certain behaviors (e.g. to simulate OOMs). Several of Spark's test suites have memory-leak detection to ensure that all allocated memory is cleaned up at the end of each test case; this helps to guard against bugs that could cause production memory leaks. For example, see `testWithMemoryLeakDetection` in `UnsafeFixedWidthAggregationMapSuite`. Unfortunately, however, this leak-detection logic is broken for tests which use TestMemoryManager because it does not override the `getExecutionMemoryUsageForTask()` method that is used by the leak-detection checks. This PR fixes that problem, thereby strengthening our existing tests. I spotted this problem while reviewing #25953: I tried introducing a change to remove a `freePage()` call (purposely inducing a memory leak) but no tests failed. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a new `TestMemoryManagerSuite`, with tests covering `TestMemoryManager` itself. Ran a subset of existing tests on my laptop and uncovered a bug in one test's `free()` calls, plus missing cleanup calls in another test suite; both of these issues are fixed in this PR. Closes #25985 from JoshRosen/SPARK-29310-testmemorymanager-getExecutionMemoryUsageForTask. Lead-authored-by: Josh Rosen <rosenville@gmail.com> Co-authored-by: joshrosen-stripe <48632449+joshrosen-stripe@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
2ec3265ae7
commit
c6938eab57
|
@ -54,7 +54,7 @@ public abstract class MemoryConsumer {
|
|||
/**
|
||||
* Returns the size of used memory in bytes.
|
||||
*/
|
||||
protected long getUsed() {
|
||||
public long getUsed() {
|
||||
return used;
|
||||
}
|
||||
|
||||
|
|
|
@ -692,13 +692,11 @@ public abstract class AbstractBytesToBytesMapSuite {
|
|||
|
||||
Thread thread = new Thread(() -> {
|
||||
int i = 0;
|
||||
long used = 0;
|
||||
while (i < 10) {
|
||||
c1.use(10000000);
|
||||
used += 10000000;
|
||||
i++;
|
||||
}
|
||||
c1.free(used);
|
||||
c1.free(c1.getUsed());
|
||||
});
|
||||
|
||||
try {
|
||||
|
|
|
@ -235,6 +235,9 @@ public class UnsafeExternalSorterSuite {
|
|||
sorter.insertRecord(null, 0, 0, 0, false);
|
||||
UnsafeSorterIterator iter = sorter.getSortedIterator();
|
||||
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
|
||||
|
||||
sorter.cleanupResources();
|
||||
assertSpillFilesWereCleanedUp();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -510,6 +513,8 @@ public class UnsafeExternalSorterSuite {
|
|||
verifyIntIterator(sorter.getIterator(79), 79, 300);
|
||||
verifyIntIterator(sorter.getIterator(139), 139, 300);
|
||||
verifyIntIterator(sorter.getIterator(279), 279, 300);
|
||||
sorter.cleanupResources();
|
||||
assertSpillFilesWereCleanedUp();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -17,60 +17,110 @@
|
|||
|
||||
package org.apache.spark.memory
|
||||
|
||||
import javax.annotation.concurrent.GuardedBy
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.storage.BlockId
|
||||
|
||||
class TestMemoryManager(conf: SparkConf)
|
||||
extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) {
|
||||
|
||||
@GuardedBy("this")
|
||||
private var consequentOOM = 0
|
||||
@GuardedBy("this")
|
||||
private var available = Long.MaxValue
|
||||
@GuardedBy("this")
|
||||
private val memoryForTask = mutable.HashMap[Long, Long]().withDefaultValue(0L)
|
||||
|
||||
override private[memory] def acquireExecutionMemory(
|
||||
numBytes: Long,
|
||||
taskAttemptId: Long,
|
||||
memoryMode: MemoryMode): Long = {
|
||||
if (consequentOOM > 0) {
|
||||
consequentOOM -= 1
|
||||
0
|
||||
} else if (available >= numBytes) {
|
||||
available -= numBytes
|
||||
numBytes
|
||||
} else {
|
||||
val grant = available
|
||||
available = 0
|
||||
grant
|
||||
memoryMode: MemoryMode): Long = synchronized {
|
||||
require(numBytes >= 0)
|
||||
val acquired = {
|
||||
if (consequentOOM > 0) {
|
||||
consequentOOM -= 1
|
||||
0
|
||||
} else if (available >= numBytes) {
|
||||
available -= numBytes
|
||||
numBytes
|
||||
} else {
|
||||
val grant = available
|
||||
available = 0
|
||||
grant
|
||||
}
|
||||
}
|
||||
memoryForTask(taskAttemptId) = memoryForTask.getOrElse(taskAttemptId, 0L) + acquired
|
||||
acquired
|
||||
}
|
||||
override def acquireStorageMemory(
|
||||
blockId: BlockId,
|
||||
numBytes: Long,
|
||||
memoryMode: MemoryMode): Boolean = true
|
||||
override def acquireUnrollMemory(
|
||||
blockId: BlockId,
|
||||
numBytes: Long,
|
||||
memoryMode: MemoryMode): Boolean = true
|
||||
override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {}
|
||||
|
||||
override private[memory] def releaseExecutionMemory(
|
||||
numBytes: Long,
|
||||
taskAttemptId: Long,
|
||||
memoryMode: MemoryMode): Unit = {
|
||||
memoryMode: MemoryMode): Unit = synchronized {
|
||||
require(numBytes >= 0)
|
||||
available += numBytes
|
||||
val existingMemoryUsage = memoryForTask.getOrElse(taskAttemptId, 0L)
|
||||
val newMemoryUsage = existingMemoryUsage - numBytes
|
||||
require(
|
||||
newMemoryUsage >= 0,
|
||||
s"Attempting to free $numBytes of memory for task attempt $taskAttemptId, but it only " +
|
||||
s"allocated $existingMemoryUsage bytes of memory")
|
||||
memoryForTask(taskAttemptId) = newMemoryUsage
|
||||
}
|
||||
|
||||
override private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = {
|
||||
memoryForTask.remove(taskAttemptId).getOrElse(0L)
|
||||
}
|
||||
|
||||
override private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = {
|
||||
memoryForTask.getOrElse(taskAttemptId, 0L)
|
||||
}
|
||||
|
||||
override def acquireStorageMemory(
|
||||
blockId: BlockId,
|
||||
numBytes: Long,
|
||||
memoryMode: MemoryMode): Boolean = {
|
||||
require(numBytes >= 0)
|
||||
true
|
||||
}
|
||||
|
||||
override def acquireUnrollMemory(
|
||||
blockId: BlockId,
|
||||
numBytes: Long,
|
||||
memoryMode: MemoryMode): Boolean = {
|
||||
require(numBytes >= 0)
|
||||
true
|
||||
}
|
||||
|
||||
override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {
|
||||
require(numBytes >= 0)
|
||||
}
|
||||
|
||||
override def maxOnHeapStorageMemory: Long = Long.MaxValue
|
||||
|
||||
override def maxOffHeapStorageMemory: Long = 0L
|
||||
|
||||
private var consequentOOM = 0
|
||||
private var available = Long.MaxValue
|
||||
|
||||
/**
|
||||
* Causes the next call to [[acquireExecutionMemory()]] to fail to allocate
|
||||
* memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
|
||||
*/
|
||||
def markExecutionAsOutOfMemoryOnce(): Unit = {
|
||||
markconsequentOOM(1)
|
||||
}
|
||||
|
||||
def markconsequentOOM(n : Int) : Unit = {
|
||||
/**
|
||||
* Causes the next `n` calls to [[acquireExecutionMemory()]] to fail to allocate
|
||||
* memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
|
||||
*/
|
||||
def markconsequentOOM(n: Int): Unit = synchronized {
|
||||
consequentOOM += n
|
||||
}
|
||||
|
||||
def limit(avail: Long): Unit = {
|
||||
def limit(avail: Long): Unit = synchronized {
|
||||
require(avail >= 0)
|
||||
available = avail
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.memory
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||
|
||||
/**
|
||||
* Tests of [[TestMemoryManager]] itself.
|
||||
*/
|
||||
class TestMemoryManagerSuite extends SparkFunSuite {
|
||||
test("tracks allocated execution memory by task") {
|
||||
val testMemoryManager = new TestMemoryManager(new SparkConf())
|
||||
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
|
||||
|
||||
testMemoryManager.acquireExecutionMemory(10, 0, MemoryMode.ON_HEAP)
|
||||
testMemoryManager.acquireExecutionMemory(5, 1, MemoryMode.ON_HEAP)
|
||||
testMemoryManager.acquireExecutionMemory(5, 0, MemoryMode.ON_HEAP)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 15)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 5)
|
||||
|
||||
testMemoryManager.releaseExecutionMemory(10, 0, MemoryMode.ON_HEAP)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 5)
|
||||
|
||||
testMemoryManager.releaseAllExecutionMemoryForTask(0)
|
||||
testMemoryManager.releaseAllExecutionMemoryForTask(1)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
|
||||
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
|
||||
}
|
||||
|
||||
test("markconsequentOOM") {
|
||||
val testMemoryManager = new TestMemoryManager(new SparkConf())
|
||||
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
|
||||
testMemoryManager.markconsequentOOM(2)
|
||||
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
|
||||
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
|
||||
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue