[SPARK-31952][SQL] Fix incorrect memory spill metric when doing Aggregate
### What changes were proposed in this pull request? This PR takes over https://github.com/apache/spark/pull/28780. 1. Counted the spilled memory size when creating the `UnsafeExternalSorter` with the existing `InMemorySorter` 2. Accumulate the `totalSpillBytes` when merging two `UnsafeExternalSorter` ### Why are the changes needed? As mentioned in https://github.com/apache/spark/pull/28780: > It happends when hash aggregate downgrades to sort based aggregate. `UnsafeExternalSorter.createWithExistingInMemorySorter` calls spill on an `InMemorySorter` immediately, but the memory pointed by `InMemorySorter` is acquired by outside `BytesToBytesMap`, instead the allocatedPages in `UnsafeExternalSorter`. So the memory spill bytes metric is always 0, but disk bytes spill metric is right. Besides, this PR also fixes the `UnsafeExternalSorter.merge` by accumulating the `totalSpillBytes` of two sorters. Thus, we can report the correct spilled size in `HashAggregateExec.finishAggregate`. Issues can be reproduced by the following step by checking the SQL metrics in UI: ``` bin/spark-shell --driver-memory 512m --executor-memory 512m --executor-cores 1 --conf "spark.default.parallelism=1" scala> sql("select id, count(1) from range(10000000) group by id").write.csv("/tmp/result.json") ``` Before: <img width="200" alt="WeChatfe5146180d91015e03b9a27852e9a443" src="https://user-images.githubusercontent.com/16397174/103625414-e6fc6280-4f75-11eb-8b93-c55095bdb5b8.png"> After: <img width="200" alt="WeChat42ab0e73c5fbc3b14c12ab85d232071d" src="https://user-images.githubusercontent.com/16397174/103625420-e8c62600-4f75-11eb-8e1f-6f5e8ab561b9.png"> ### Does this PR introduce _any_ user-facing change? Yes, users can see the correct spill metrics after this PR. ### How was this patch tested? Tested manually and added UTs. Closes #31035 from Ngone51/SPARK-31952. Lead-authored-by: yi.wu <yi.wu@databricks.com> Co-authored-by: wangguangxin.cn <wangguangxin.cn@bytedance.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
d97e99157e
commit
4afca0f706
|
@ -104,11 +104,14 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
|
|||
int initialSize,
|
||||
long pageSizeBytes,
|
||||
int numElementsForSpillThreshold,
|
||||
UnsafeInMemorySorter inMemorySorter) throws IOException {
|
||||
UnsafeInMemorySorter inMemorySorter,
|
||||
long existingMemoryConsumption) throws IOException {
|
||||
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
|
||||
serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
|
||||
pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */);
|
||||
sorter.spill(Long.MAX_VALUE, sorter);
|
||||
taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption);
|
||||
sorter.totalSpillBytes += existingMemoryConsumption;
|
||||
// The external sorter will be used to insert records, in-memory sorter is not needed.
|
||||
sorter.inMemSorter = null;
|
||||
return sorter;
|
||||
|
@ -496,6 +499,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
|
|||
*/
|
||||
public void merge(UnsafeExternalSorter other) throws IOException {
|
||||
other.spill();
|
||||
totalSpillBytes += other.totalSpillBytes;
|
||||
spillWriters.addAll(other.spillWriters);
|
||||
// remove them from `spillWriters`, or the files will be deleted in `cleanupResources`.
|
||||
other.spillWriters.clear();
|
||||
|
|
|
@ -165,7 +165,8 @@ public final class UnsafeKVExternalSorter {
|
|||
(int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()),
|
||||
pageSizeBytes,
|
||||
numElementsForSpillThreshold,
|
||||
inMemSorter);
|
||||
inMemSorter,
|
||||
map.getTotalMemoryConsumption());
|
||||
|
||||
// reset the map, so we can re-use it to insert new records. the inMemSorter will not used
|
||||
// anymore, so the underline array could be used by map again.
|
||||
|
|
|
@ -210,23 +210,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
|
|||
test("SPARK-23376: Create UnsafeKVExternalSorter with BytesToByteMap having duplicated keys") {
|
||||
val memoryManager = new TestMemoryManager(new SparkConf())
|
||||
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
|
||||
val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes())
|
||||
|
||||
// Key/value are a unsafe rows with a single int column
|
||||
val map = createBytesToBytesMapWithDuplicateKeys(taskMemoryManager)
|
||||
val schema = new StructType().add("i", IntegerType)
|
||||
val key = new UnsafeRow(1)
|
||||
key.pointTo(new Array[Byte](32), 32)
|
||||
key.setInt(0, 1)
|
||||
val value = new UnsafeRow(1)
|
||||
value.pointTo(new Array[Byte](32), 32)
|
||||
value.setInt(0, 2)
|
||||
|
||||
for (_ <- 1 to 65) {
|
||||
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
|
||||
loc.append(
|
||||
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
|
||||
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
|
||||
}
|
||||
|
||||
// Make sure we can successfully create a UnsafeKVExternalSorter with a `BytesToBytesMap`
|
||||
// which has duplicated keys and the number of entries exceeds its capacity.
|
||||
|
@ -245,4 +230,82 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession
|
|||
TaskContext.unset()
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-31952: create UnsafeKVExternalSorter with existing map should count spilled memory " +
|
||||
"size correctly") {
|
||||
val memoryManager = new TestMemoryManager(new SparkConf())
|
||||
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
|
||||
val map = createBytesToBytesMapWithDuplicateKeys(taskMemoryManager)
|
||||
val schema = new StructType().add("i", IntegerType)
|
||||
|
||||
try {
|
||||
val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null)
|
||||
TaskContext.setTaskContext(context)
|
||||
val expectedSpillSize = map.getTotalMemoryConsumption
|
||||
val sorter = new UnsafeKVExternalSorter(
|
||||
schema,
|
||||
schema,
|
||||
sparkContext.env.blockManager,
|
||||
sparkContext.env.serializerManager,
|
||||
taskMemoryManager.pageSizeBytes(),
|
||||
Int.MaxValue,
|
||||
map)
|
||||
assert(sorter.getSpillSize === expectedSpillSize)
|
||||
} finally {
|
||||
TaskContext.unset()
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-31952: UnsafeKVExternalSorter.merge should accumulate totalSpillBytes") {
|
||||
val memoryManager = new TestMemoryManager(new SparkConf())
|
||||
val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
|
||||
val map1 = createBytesToBytesMapWithDuplicateKeys(taskMemoryManager)
|
||||
val map2 = createBytesToBytesMapWithDuplicateKeys(taskMemoryManager)
|
||||
val schema = new StructType().add("i", IntegerType)
|
||||
|
||||
try {
|
||||
val context = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties(), null)
|
||||
TaskContext.setTaskContext(context)
|
||||
val expectedSpillSize = map1.getTotalMemoryConsumption + map2.getTotalMemoryConsumption
|
||||
val sorter1 = new UnsafeKVExternalSorter(
|
||||
schema,
|
||||
schema,
|
||||
sparkContext.env.blockManager,
|
||||
sparkContext.env.serializerManager,
|
||||
taskMemoryManager.pageSizeBytes(),
|
||||
Int.MaxValue,
|
||||
map1)
|
||||
val sorter2 = new UnsafeKVExternalSorter(
|
||||
schema,
|
||||
schema,
|
||||
sparkContext.env.blockManager,
|
||||
sparkContext.env.serializerManager,
|
||||
taskMemoryManager.pageSizeBytes(),
|
||||
Int.MaxValue,
|
||||
map2)
|
||||
sorter1.merge(sorter2)
|
||||
assert(sorter1.getSpillSize === expectedSpillSize)
|
||||
} finally {
|
||||
TaskContext.unset()
|
||||
}
|
||||
}
|
||||
|
||||
private def createBytesToBytesMapWithDuplicateKeys(taskMemoryManager: TaskMemoryManager)
|
||||
: BytesToBytesMap = {
|
||||
val map = new BytesToBytesMap(taskMemoryManager, 64, taskMemoryManager.pageSizeBytes())
|
||||
// Key/value are a unsafe rows with a single int column
|
||||
val key = new UnsafeRow(1)
|
||||
key.pointTo(new Array[Byte](32), 32)
|
||||
key.setInt(0, 1)
|
||||
val value = new UnsafeRow(1)
|
||||
value.pointTo(new Array[Byte](32), 32)
|
||||
value.setInt(0, 2)
|
||||
for (_ <- 1 to 65) {
|
||||
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
|
||||
loc.append(
|
||||
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
|
||||
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
|
||||
}
|
||||
map
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue