[SPARK-15391] [SQL] manage the temporary memory of timsort

## What changes were proposed in this pull request?

Currently, the memory for the temporary buffer used by TimSort is always allocated on-heap without bookkeeping, which can cause OOM in both on-heap and off-heap mode.

This PR manages that memory by preallocating it together with the pointer array, the same way RadixSort does. It works in both on-heap and off-heap mode.

This PR also changes the loadFactor of BytesToBytesMap to 0.5 (it was 0.70), which enables the use of radix sort and also makes sure that we have enough memory for TimSort.
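
In other words, the sorter now carves its temporary buffer out of the tail of the already-bookkept pointer array instead of allocating a fresh on-heap array at sort time. A minimal standalone sketch of the arithmetic (plain `long[]` and made-up numbers rather than Spark's `LongArray`/`MemoryConsumer`; class and variable names here are illustrative):

```java
// Sketch only: one bookkept allocation serves both the record pointers
// (the front of the array) and the sort's temporary buffer (the tail).
public class UsableCapacitySketch {
  public static void main(String[] args) {
    long[] array = new long[12];  // stands in for the pointer array
    boolean useRadixSort = false;

    // Radix sort needs scratch space as large as the used region (factor 2);
    // TimSort needs scratch half the size of the used region (factor 1.5).
    int usableCapacity = (int) (array.length / (useRadixSort ? 2 : 1.5));

    int pos = usableCapacity;               // pretend the usable part is full
    int bufferLength = array.length - pos;  // tail reserved for sorting

    // The tail is at least half the used region, so TimSort never has to
    // allocate an unbookkept temporary array of its own.
    assert bufferLength * 2 >= pos : bufferLength + " longs is too small";
    System.out.println("usable = " + usableCapacity + ", buffer = " + bufferLength);
  }
}
```

The diffs below implement exactly this split, as `getUsableCapacity()` plus a `MemoryBlock` wrapper over the unused tail.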

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #13318 from davies/fix_timsort.
Davies Liu authored 2016-06-03 16:45:09 -07:00, committed by Davies Liu
parent 67cc89ff02
commit 3074f575a3
13 changed files with 120 additions and 69 deletions

### MemoryBlock.java

```diff
@@ -51,6 +51,6 @@ public class MemoryBlock extends MemoryLocation {
    * Creates a memory block pointing to the memory used by the long array.
    */
   public static MemoryBlock fromLongArray(final long[] array) {
-    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
+    return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L);
   }
 }
```

### ShuffleInMemorySorter.java

```diff
@@ -22,12 +22,12 @@ import java.util.Comparator;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 import org.apache.spark.util.collection.unsafe.sort.RadixSort;
 
 final class ShuffleInMemorySorter {
 
-  private final Sorter<PackedRecordPointer, LongArray> sorter;
   private static final class SortComparator implements Comparator<PackedRecordPointer> {
     @Override
     public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@@ -44,6 +44,9 @@ final class ShuffleInMemorySorter {
    * An array of record pointers and partition ids that have been encoded by
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
+   *
+   * Only part of the array will be used to store the pointers, the rest part is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
@@ -53,16 +56,16 @@ final class ShuffleInMemorySorter {
    */
   private final boolean useRadixSort;
 
-  /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
-   */
-  private final int memoryAllocationFactor;
-
   /**
    * The position in the pointer array where new records can be inserted.
    */
   private int pos = 0;
 
+  /**
+   * How many records could be inserted, because part of the array should be left for sorting.
+   */
+  private int usableCapacity = 0;
+
   private int initialSize;
 
   ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) {
@@ -70,9 +73,14 @@ final class ShuffleInMemorySorter {
     assert (initialSize > 0);
     this.initialSize = initialSize;
     this.useRadixSort = useRadixSort;
-    this.memoryAllocationFactor = useRadixSort ? 2 : 1;
     this.array = consumer.allocateArray(initialSize);
-    this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (useRadixSort ? 2 : 1.5));
   }
 
   public void free() {
@@ -89,7 +97,8 @@ final class ShuffleInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -101,14 +110,15 @@ final class ShuffleInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor)
+      pos * 8L
     );
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos < array.size() / memoryAllocationFactor;
+    return pos < usableCapacity;
   }
 
   public long getMemoryUsage() {
@@ -170,6 +180,14 @@ final class ShuffleInMemorySorter {
         PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
         PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
     } else {
+      MemoryBlock unused = new MemoryBlock(
+        array.getBaseObject(),
+        array.getBaseOffset() + pos * 8L,
+        (array.size() - pos) * 8L);
+      LongArray buffer = new LongArray(unused);
+      Sorter<PackedRecordPointer, LongArray> sorter =
+        new Sorter<>(new ShuffleSortDataFormat(buffer));
       sorter.sort(array, 0, pos, SORT_COMPARATOR);
     }
     return new ShuffleSorterIterator(pos, array, offset);
```

### ShuffleSortDataFormat.java

```diff
@@ -19,14 +19,15 @@ package org.apache.spark.shuffle.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
 
-  public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
+  private final LongArray buffer;
 
-  private ShuffleSortDataFormat() { }
+  ShuffleSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public PackedRecordPointer getKey(LongArray data, int pos) {
@@ -70,8 +71,8 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, Lo
 
   @Override
   public LongArray allocate(int length) {
-    // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length]));
+    assert (length <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + length;
+    return buffer;
   }
 }
```

### BytesToBytesMap.java

```diff
@@ -221,7 +221,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
       SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
       SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null,
       initialCapacity,
-      0.70,
+      // In order to re-use the longArray for sorting, the load factor cannot be larger than 0.5.
+      0.5,
       pageSizeBytes,
       enablePerfMetrics);
   }
```

### UnsafeInMemorySorter.java

```diff
@@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
 
 /**
@@ -69,8 +70,6 @@ public final class UnsafeInMemorySorter {
   private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
   @Nullable
-  private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
-  @Nullable
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
@@ -79,14 +78,12 @@ public final class UnsafeInMemorySorter {
   @Nullable
   private final PrefixComparators.RadixSortSupport radixSortSupport;
 
-  /**
-   * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
-   */
-  private final int memoryAllocationFactor;
-
   /**
    * Within this buffer, position {@code 2 * i} holds a pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+   *
+   * Only part of the array will be used to store the pointers, the rest part is preserved as
+   * temporary buffer for sorting.
    */
   private LongArray array;
 
@@ -95,6 +92,11 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;
 
+  /**
+   * How many records could be inserted, because part of the array should be left for sorting.
+   */
+  private int usableCapacity = 0;
+
   private long initialSize;
 
   private long totalSortTimeNanos = 0L;
@@ -121,7 +123,6 @@ public final class UnsafeInMemorySorter {
     this.memoryManager = memoryManager;
     this.initialSize = array.size();
     if (recordComparator != null) {
-      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
       if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
         this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
@@ -129,12 +130,17 @@ public final class UnsafeInMemorySorter {
         this.radixSortSupport = null;
       }
     } else {
-      this.sorter = null;
       this.sortComparator = null;
       this.radixSortSupport = null;
     }
-    this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
     this.array = array;
+    this.usableCapacity = getUsableCapacity();
+  }
+
+  private int getUsableCapacity() {
+    // Radix sort requires same amount of used memory as buffer, Tim sort requires
+    // half of the used memory as buffer.
+    return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
   }
 
   /**
@@ -150,7 +156,8 @@ public final class UnsafeInMemorySorter {
   public void reset() {
     if (consumer != null) {
       consumer.freeArray(array);
-      this.array = consumer.allocateArray(initialSize);
+      array = consumer.allocateArray(initialSize);
+      usableCapacity = getUsableCapacity();
     }
     pos = 0;
   }
@@ -174,7 +181,7 @@ public final class UnsafeInMemorySorter {
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos + 1 < (array.size() / memoryAllocationFactor);
+    return pos + 1 < usableCapacity;
   }
 
   public void expandPointerArray(LongArray newArray) {
@@ -186,9 +193,10 @@ public final class UnsafeInMemorySorter {
       array.getBaseOffset(),
       newArray.getBaseObject(),
       newArray.getBaseOffset(),
-      array.size() * (8 / memoryAllocationFactor));
+      pos * 8L);
     consumer.freeArray(array);
     array = newArray;
+    usableCapacity = getUsableCapacity();
   }
 
   /**
@@ -275,13 +283,20 @@ public final class UnsafeInMemorySorter {
   public SortedIterator getSortedIterator() {
     int offset = 0;
     long start = System.nanoTime();
-    if (sorter != null) {
+    if (sortComparator != null) {
       if (this.radixSortSupport != null) {
         // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
         // force a full-width sort (and we cannot radix-sort nullable long fields at all).
         offset = RadixSort.sortKeyPrefixArray(
           array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
       } else {
+        MemoryBlock unused = new MemoryBlock(
+          array.getBaseObject(),
+          array.getBaseOffset() + pos * 8L,
+          (array.size() - pos) * 8L);
+        LongArray buffer = new LongArray(unused);
+        Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
+          new Sorter<>(new UnsafeSortDataFormat(buffer));
         sorter.sort(array, 0, pos / 2, sortComparator);
       }
     }
```
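
A worked check of the 1.5 factor for `UnsafeInMemorySorter`, where each record occupies two longs (pointer + prefix) and TimSort's merges need scratch space for at most half the records (the smaller run); the numbers below are illustrative:

```java
// Sketch: with 2/3 of the array usable, the remaining 1/3 is exactly the
// worst-case TimSort temporary buffer (half the records, 2 longs each).
public class TimSortBufferCheck {
  public static void main(String[] args) {
    int n = 12;                           // longs in the pointer array
    int usable = (int) (n / 1.5);         // 8 longs may hold records
    int records = usable / 2;             // 4 records when completely full
    int tailLongs = n - usable;           // 4 longs left as the temp buffer

    // TimSort needs scratch for at most records / 2 elements,
    // i.e. at most `records` longs here; the tail covers that.
    int worstCaseTempLongs = (records / 2) * 2;
    assert worstCaseTempLongs <= tailLongs : "tail too small for TimSort";
    System.out.println("records = " + records + ", temp longs = " + worstCaseTempLongs);
  }
}
```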

### UnsafeSortDataFormat.java

```diff
@@ -19,7 +19,6 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -32,9 +31,11 @@ import org.apache.spark.util.collection.SortDataFormat;
 public final class UnsafeSortDataFormat
   extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
 
-  public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
+  private final LongArray buffer;
 
-  private UnsafeSortDataFormat() { }
+  public UnsafeSortDataFormat(LongArray buffer) {
+    this.buffer = buffer;
+  }
 
   @Override
   public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
@@ -83,9 +84,9 @@ public final class UnsafeSortDataFormat
 
   @Override
   public LongArray allocate(int length) {
-    assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
-    // This is used as temporary buffer, it's fine to allocate from JVM heap.
-    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
+    assert (length * 2 <= buffer.size()) :
+      "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2);
+    return buffer;
   }
 }
```

### UnsafeShuffleWriterSuite.java

```diff
@@ -21,12 +21,15 @@ import java.io.*;
 import java.nio.ByteBuffer;
 import java.util.*;
 
-import scala.*;
+import scala.Option;
+import scala.Product2;
+import scala.Tuple2;
+import scala.Tuple2$;
 import scala.collection.Iterator;
 import scala.runtime.AbstractFunction1;
 
-import com.google.common.collect.Iterators;
 import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
 import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
@@ -35,6 +38,26 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.*;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.lessThan;
@@ -42,22 +65,6 @@ import static org.junit.Assert.*;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
-import org.apache.spark.*;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZ4CompressionCodec;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.storage.*;
-import org.apache.spark.memory.TestMemoryManager;
-import org.apache.spark.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
 public class UnsafeShuffleWriterSuite {
 
   static final int NUM_PARTITITONS = 4;
```

### AbstractBytesToBytesMapSuite.java

```diff
@@ -589,7 +589,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Test
   public void multipleValuesForSameKey() {
     BytesToBytesMap map =
-      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024, false);
     try {
       int i;
       for (i = 0; i < 1024; i++) {
```

### ExternalSorterSuite.scala

```diff
@@ -106,8 +106,10 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
     val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i }
     val buf = new LongArray(MemoryBlock.fromLongArray(ref))
+    val tmp = new Array[Long](size/2)
+    val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
 
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
       buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
```

### RadixSortSuite.scala

```diff
@@ -93,7 +93,8 @@ class RadixSortSuite extends SparkFunSuite with Logging {
   }
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
```

### UnsafeKVExternalSorter.java

```diff
@@ -73,6 +73,8 @@ public final class UnsafeKVExternalSorter {
     PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
     BaseOrdering ordering = GenerateOrdering.create(keySchema);
     KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+    boolean canUseRadixSort = keySchema.length() == 1 &&
+      SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
     TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
@@ -86,14 +88,16 @@ public final class UnsafeKVExternalSorter {
         prefixComparator,
         /* initialSize */ 4096,
         pageSizeBytes,
-        keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)));
+        canUseRadixSort);
     } else {
+      // The array will be used to do in-place sort, which require half of the space to be empty.
+      assert(map.numKeys() <= map.getArray().size() / 2);
       // During spilling, the array in map will not be used, so we can borrow that and use it
       // as the underline array for in-memory sorter (it's always large enough).
       // Since we will not grow the array, it's fine to pass `null` as consumer.
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
         null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
-        false /* TODO(ekl) we can only radix sort if the BytesToBytes load factor is <= 0.5 */);
+        canUseRadixSort);
 
       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
```

### HashedRelation.scala (LongToUnsafeRowMap)

```diff
@@ -540,7 +540,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
         cursor - Platform.LONG_ARRAY_OFFSET)
       page = newPage
-      freeMemory(used * 8)
+      freeMemory(used * 8L)
     }
 
     // copy the bytes of UnsafeRow
@@ -599,7 +599,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       i += 2
     }
     old_array = null  // release the reference to old array
-    freeMemory(n * 8)
+    freeMemory(n * 8L)
   }
 
   /**
@@ -610,7 +610,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     // Convert to dense mode if it does not require more memory or could fit within L1 cache
     if (range < array.length || range < 1024) {
       try {
-        ensureAcquireMemory((range + 1) * 8)
+        ensureAcquireMemory((range + 1) * 8L)
       } catch {
         case e: SparkException =>
           // there is no enough memory to convert
@@ -628,7 +628,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       val old_length = array.length
       array = denseArray
       isDense = true
-      freeMemory(old_length * 8)
+      freeMemory(old_length * 8L)
     }
   }
 
@@ -637,11 +637,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length * 8)
+      freeMemory(page.length * 8L)
       page = null
     }
     if (array != null) {
-      freeMemory(array.length * 8)
+      freeMemory(array.length * 8L)
       array = null
     }
   }
```

### SortBenchmark.scala

```diff
@@ -36,7 +36,8 @@ import org.apache.spark.util.random.XORShiftRandom
 class SortBenchmark extends BenchmarkBase {
 
   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
-    new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
             r1: RecordPointerAndKeyPrefix,
```