[SPARK-29310][CORE][TESTS] TestMemoryManager should implement getExecutionMemoryUsageForTask()
### What changes were proposed in this pull request? This PR updates `TestMemoryManager`, a class used only in unit tests, to override the `getExecutionMemoryUsageForTask()` and `releaseAllExecutionMemoryForTask()` methods. I also made its state-accessing methods `synchronized` (to make the class thread-safe) and added some additional assertions to guard against freeing more memory than has been allocated. ### Why are the changes needed? Spark uses a `TestMemoryManager` class to mock out memory manager functionality in tests, allowing test authors to exercise control over certain behaviors (e.g. to simulate OOMs). Several of Spark's test suites have memory-leak detection to ensure that all allocated memory is cleaned up at the end of each test case; this helps to guard against bugs that could cause production memory leaks. For example, see `testWithMemoryLeakDetection` in `UnsafeFixedWidthAggregationMapSuite`. Unfortunately, however, this leak-detection logic is broken for tests which use TestMemoryManager because it does not override the `getExecutionMemoryUsageForTask()` method that is used by the leak-detection checks. This PR fixes that problem, thereby strengthening our existing tests. I spotted this problem while reviewing #25953: I tried introducing a change to remove a `freePage()` call (purposely inducing a memory leak) but no tests failed. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a new `TestMemoryManagerSuite`, with tests covering `TestMemoryManager` itself. Ran a subset of existing tests on my laptop and uncovered a bug in one test's `free()` calls, plus missing cleanup calls in another test suite; both of these issues are fixed in this PR. Closes #25985 from JoshRosen/SPARK-29310-testmemorymanager-getExecutionMemoryUsageForTask. 
Lead-authored-by: Josh Rosen <rosenville@gmail.com> Co-authored-by: joshrosen-stripe <48632449+joshrosen-stripe@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
2ec3265ae7
commit
c6938eab57
|
@ -54,7 +54,7 @@ public abstract class MemoryConsumer {
|
||||||
/**
|
/**
|
||||||
* Returns the size of used memory in bytes.
|
* Returns the size of used memory in bytes.
|
||||||
*/
|
*/
|
||||||
protected long getUsed() {
|
public long getUsed() {
|
||||||
return used;
|
return used;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -692,13 +692,11 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
|
|
||||||
Thread thread = new Thread(() -> {
|
Thread thread = new Thread(() -> {
|
||||||
int i = 0;
|
int i = 0;
|
||||||
long used = 0;
|
|
||||||
while (i < 10) {
|
while (i < 10) {
|
||||||
c1.use(10000000);
|
c1.use(10000000);
|
||||||
used += 10000000;
|
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
c1.free(used);
|
c1.free(c1.getUsed());
|
||||||
});
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
|
@ -235,6 +235,9 @@ public class UnsafeExternalSorterSuite {
|
||||||
sorter.insertRecord(null, 0, 0, 0, false);
|
sorter.insertRecord(null, 0, 0, 0, false);
|
||||||
UnsafeSorterIterator iter = sorter.getSortedIterator();
|
UnsafeSorterIterator iter = sorter.getSortedIterator();
|
||||||
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
|
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
|
||||||
|
|
||||||
|
sorter.cleanupResources();
|
||||||
|
assertSpillFilesWereCleanedUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -510,6 +513,8 @@ public class UnsafeExternalSorterSuite {
|
||||||
verifyIntIterator(sorter.getIterator(79), 79, 300);
|
verifyIntIterator(sorter.getIterator(79), 79, 300);
|
||||||
verifyIntIterator(sorter.getIterator(139), 139, 300);
|
verifyIntIterator(sorter.getIterator(139), 139, 300);
|
||||||
verifyIntIterator(sorter.getIterator(279), 279, 300);
|
verifyIntIterator(sorter.getIterator(279), 279, 300);
|
||||||
|
sorter.cleanupResources();
|
||||||
|
assertSpillFilesWereCleanedUp();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -17,60 +17,110 @@
|
||||||
|
|
||||||
package org.apache.spark.memory
|
package org.apache.spark.memory
|
||||||
|
|
||||||
|
import javax.annotation.concurrent.GuardedBy
|
||||||
|
|
||||||
|
import scala.collection.mutable
|
||||||
|
|
||||||
import org.apache.spark.SparkConf
|
import org.apache.spark.SparkConf
|
||||||
import org.apache.spark.storage.BlockId
|
import org.apache.spark.storage.BlockId
|
||||||
|
|
||||||
class TestMemoryManager(conf: SparkConf)
|
class TestMemoryManager(conf: SparkConf)
|
||||||
extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) {
|
extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) {
|
||||||
|
|
||||||
|
@GuardedBy("this")
|
||||||
|
private var consequentOOM = 0
|
||||||
|
@GuardedBy("this")
|
||||||
|
private var available = Long.MaxValue
|
||||||
|
@GuardedBy("this")
|
||||||
|
private val memoryForTask = mutable.HashMap[Long, Long]().withDefaultValue(0L)
|
||||||
|
|
||||||
override private[memory] def acquireExecutionMemory(
|
override private[memory] def acquireExecutionMemory(
|
||||||
numBytes: Long,
|
numBytes: Long,
|
||||||
taskAttemptId: Long,
|
taskAttemptId: Long,
|
||||||
memoryMode: MemoryMode): Long = {
|
memoryMode: MemoryMode): Long = synchronized {
|
||||||
if (consequentOOM > 0) {
|
require(numBytes >= 0)
|
||||||
consequentOOM -= 1
|
val acquired = {
|
||||||
0
|
if (consequentOOM > 0) {
|
||||||
} else if (available >= numBytes) {
|
consequentOOM -= 1
|
||||||
available -= numBytes
|
0
|
||||||
numBytes
|
} else if (available >= numBytes) {
|
||||||
} else {
|
available -= numBytes
|
||||||
val grant = available
|
numBytes
|
||||||
available = 0
|
} else {
|
||||||
grant
|
val grant = available
|
||||||
|
available = 0
|
||||||
|
grant
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
memoryForTask(taskAttemptId) = memoryForTask.getOrElse(taskAttemptId, 0L) + acquired
|
||||||
|
acquired
|
||||||
}
|
}
|
||||||
override def acquireStorageMemory(
|
|
||||||
blockId: BlockId,
|
|
||||||
numBytes: Long,
|
|
||||||
memoryMode: MemoryMode): Boolean = true
|
|
||||||
override def acquireUnrollMemory(
|
|
||||||
blockId: BlockId,
|
|
||||||
numBytes: Long,
|
|
||||||
memoryMode: MemoryMode): Boolean = true
|
|
||||||
override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {}
|
|
||||||
override private[memory] def releaseExecutionMemory(
|
override private[memory] def releaseExecutionMemory(
|
||||||
numBytes: Long,
|
numBytes: Long,
|
||||||
taskAttemptId: Long,
|
taskAttemptId: Long,
|
||||||
memoryMode: MemoryMode): Unit = {
|
memoryMode: MemoryMode): Unit = synchronized {
|
||||||
|
require(numBytes >= 0)
|
||||||
available += numBytes
|
available += numBytes
|
||||||
|
val existingMemoryUsage = memoryForTask.getOrElse(taskAttemptId, 0L)
|
||||||
|
val newMemoryUsage = existingMemoryUsage - numBytes
|
||||||
|
require(
|
||||||
|
newMemoryUsage >= 0,
|
||||||
|
s"Attempting to free $numBytes of memory for task attempt $taskAttemptId, but it only " +
|
||||||
|
s"allocated $existingMemoryUsage bytes of memory")
|
||||||
|
memoryForTask(taskAttemptId) = newMemoryUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = {
|
||||||
|
memoryForTask.remove(taskAttemptId).getOrElse(0L)
|
||||||
|
}
|
||||||
|
|
||||||
|
override private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = {
|
||||||
|
memoryForTask.getOrElse(taskAttemptId, 0L)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def acquireStorageMemory(
|
||||||
|
blockId: BlockId,
|
||||||
|
numBytes: Long,
|
||||||
|
memoryMode: MemoryMode): Boolean = {
|
||||||
|
require(numBytes >= 0)
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
override def acquireUnrollMemory(
|
||||||
|
blockId: BlockId,
|
||||||
|
numBytes: Long,
|
||||||
|
memoryMode: MemoryMode): Boolean = {
|
||||||
|
require(numBytes >= 0)
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
override def releaseStorageMemory(numBytes: Long, memoryMode: MemoryMode): Unit = {
|
||||||
|
require(numBytes >= 0)
|
||||||
|
}
|
||||||
|
|
||||||
override def maxOnHeapStorageMemory: Long = Long.MaxValue
|
override def maxOnHeapStorageMemory: Long = Long.MaxValue
|
||||||
|
|
||||||
override def maxOffHeapStorageMemory: Long = 0L
|
override def maxOffHeapStorageMemory: Long = 0L
|
||||||
|
|
||||||
private var consequentOOM = 0
|
/**
|
||||||
private var available = Long.MaxValue
|
* Causes the next call to [[acquireExecutionMemory()]] to fail to allocate
|
||||||
|
* memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
|
||||||
|
*/
|
||||||
def markExecutionAsOutOfMemoryOnce(): Unit = {
|
def markExecutionAsOutOfMemoryOnce(): Unit = {
|
||||||
markconsequentOOM(1)
|
markconsequentOOM(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
def markconsequentOOM(n : Int) : Unit = {
|
/**
|
||||||
|
* Causes the next `n` calls to [[acquireExecutionMemory()]] to fail to allocate
|
||||||
|
* memory (returning `0`), simulating low-on-memory / out-of-memory conditions.
|
||||||
|
*/
|
||||||
|
def markconsequentOOM(n: Int): Unit = synchronized {
|
||||||
consequentOOM += n
|
consequentOOM += n
|
||||||
}
|
}
|
||||||
|
|
||||||
def limit(avail: Long): Unit = {
|
def limit(avail: Long): Unit = synchronized {
|
||||||
|
require(avail >= 0)
|
||||||
available = avail
|
available = avail
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.memory
|
||||||
|
|
||||||
|
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests of [[TestMemoryManager]] itself.
|
||||||
|
*/
|
||||||
|
class TestMemoryManagerSuite extends SparkFunSuite {
|
||||||
|
test("tracks allocated execution memory by task") {
|
||||||
|
val testMemoryManager = new TestMemoryManager(new SparkConf())
|
||||||
|
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
|
||||||
|
|
||||||
|
testMemoryManager.acquireExecutionMemory(10, 0, MemoryMode.ON_HEAP)
|
||||||
|
testMemoryManager.acquireExecutionMemory(5, 1, MemoryMode.ON_HEAP)
|
||||||
|
testMemoryManager.acquireExecutionMemory(5, 0, MemoryMode.ON_HEAP)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 15)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 5)
|
||||||
|
|
||||||
|
testMemoryManager.releaseExecutionMemory(10, 0, MemoryMode.ON_HEAP)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 5)
|
||||||
|
|
||||||
|
testMemoryManager.releaseAllExecutionMemoryForTask(0)
|
||||||
|
testMemoryManager.releaseAllExecutionMemoryForTask(1)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(0) == 0)
|
||||||
|
assert(testMemoryManager.getExecutionMemoryUsageForTask(1) == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("markconsequentOOM") {
|
||||||
|
val testMemoryManager = new TestMemoryManager(new SparkConf())
|
||||||
|
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
|
||||||
|
testMemoryManager.markconsequentOOM(2)
|
||||||
|
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
|
||||||
|
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 0)
|
||||||
|
assert(testMemoryManager.acquireExecutionMemory(1, 0, MemoryMode.ON_HEAP) == 1)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue