[SPARK-7542][SQL] Support off-heap index/sort buffer

This adds support for off-heap memory for the pointer arrays inside BytesToBytesMap and the in-memory sorters, so that all execution memory can be allocated off-heap.

Closes #8068

Author: Davies Liu <davies@databricks.com>

Closes #9477 from davies/unsafe_timsort.
Authored by Davies Liu on 2015-11-05 19:02:18 -08:00; committed by Josh Rosen
parent 3cc2c053b5
commit eec74ba8bd
17 changed files with 266 additions and 190 deletions
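
The core change: pointer buffers that used to be plain long[] become LongArray, a view over a MemoryBlock that may live on or off the JVM heap. A minimal sketch of the abstraction, using only names that appear in this diff (the on-heap path shown here reproduces the old behavior):

    import org.apache.spark.unsafe.array.LongArray;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class LongArraySketch {
      public static void main(String[] args) {
        // Wrapping a plain long[] gives the old on-heap behavior.
        LongArray array = new LongArray(MemoryBlock.fromLongArray(new long[1024]));
        array.set(0, 42L);
        System.out.println(array.get(0) + " / " + array.size() + " slots");
        // After this patch the sorters instead obtain their array via
        //   LongArray array = consumer.allocateArray(1024);
        // so the backing page can be off-heap, depending on configuration.
      }
    }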

File: MemoryConsumer.java

@ -20,6 +20,7 @@ package org.apache.spark.memory;
import java.io.IOException;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
@ -28,9 +29,9 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
*/
public abstract class MemoryConsumer {
private final TaskMemoryManager taskMemoryManager;
protected final TaskMemoryManager taskMemoryManager;
private final long pageSize;
private long used;
protected long used;
protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
this.taskMemoryManager = taskMemoryManager;
@ -74,26 +75,29 @@ public abstract class MemoryConsumer {
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
/**
* Acquire `size` bytes memory.
*
* If there is not enough memory, throws OutOfMemoryError.
* Allocates a LongArray of `size`.
*/
protected void acquireMemory(long size) {
long got = taskMemoryManager.acquireExecutionMemory(size, this);
if (got < size) {
taskMemoryManager.releaseExecutionMemory(got, this);
public LongArray allocateArray(long size) {
long required = size * 8L;
MemoryBlock page = taskMemoryManager.allocatePage(required, this);
if (page == null || page.size() < required) {
long got = 0;
if (page != null) {
got = page.size();
taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
}
used += got;
used += required;
return new LongArray(page);
}
/**
* Release `size` bytes memory.
* Frees a LongArray.
*/
protected void releaseMemory(long size) {
used -= size;
taskMemoryManager.releaseExecutionMemory(size, this);
public void freeArray(LongArray array) {
freePage(array.memoryBlock());
}
/**
@ -109,7 +113,7 @@ public abstract class MemoryConsumer {
long got = 0;
if (page != null) {
got = page.size();
freePage(page);
taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
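
The new allocateArray/freeArray pair replaces raw byte accounting (acquireMemory/releaseMemory) for consumers. A sketch of a consumer built on the new API; the class and its field are illustrative, not part of the patch:

    import java.io.IOException;
    import org.apache.spark.memory.MemoryConsumer;
    import org.apache.spark.memory.TaskMemoryManager;
    import org.apache.spark.unsafe.array.LongArray;

    // Illustrative consumer, not from the patch.
    class PointerBuffer extends MemoryConsumer {
      private LongArray pointers;

      PointerBuffer(TaskMemoryManager manager, long pageSize) {
        super(manager, pageSize);
        // May throw OutOfMemoryError, but only after other consumers
        // (and possibly this one) have been asked to spill.
        this.pointers = allocateArray(1024);
      }

      @Override
      public long spill(long size, MemoryConsumer trigger) throws IOException {
        if (pointers == null) {
          return 0L;
        }
        // A real consumer would persist its state to disk before freeing.
        long released = pointers.size() * 8L;
        freeArray(pointers);
        pointers = null;
        return released;
      }
    }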

File: TaskMemoryManager.java

@ -137,7 +137,7 @@ public class TaskMemoryManager {
if (got < required) {
// Call spill() on other consumers to release memory
for (MemoryConsumer c: consumers) {
if (c != null && c != consumer && c.getUsed() > 0) {
if (c != consumer && c.getUsed() > 0) {
try {
long released = c.spill(required - got, consumer);
if (released > 0) {
@ -173,7 +173,9 @@ public class TaskMemoryManager {
}
}
consumers.add(consumer);
if (consumer != null) {
consumers.add(consumer);
}
logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
return got;
}

File: ShuffleExternalSorter.java

@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;
@ -114,8 +115,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
acquireMemory(initialSize * 8L);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
this.peakMemoryUsedBytes = getMemoryUsage();
}
@ -301,9 +301,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
public void cleanupResources() {
freeMemory();
if (inMemSorter != null) {
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(sorterMemoryUsage);
}
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
@ -321,9 +320,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
long needed = used + inMemSorter.getMemoryToExpand();
LongArray array;
try {
acquireMemory(needed); // could trigger spilling
// could trigger spilling
array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have triggered spilling
assert(inMemSorter.hasSpaceForAnotherRecord());
@ -331,16 +331,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
releaseMemory(needed);
freeArray(array);
} else {
try {
inMemSorter.expandPointerArray();
releaseMemory(used);
} catch (OutOfMemoryError oom) {
// Just in case that JVM had run out of memory
releaseMemory(needed);
spill();
}
inMemSorter.expandPointerArray(array);
}
}
}
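
This protocol is easy to misread: allocateArray may spill this very sorter, which frees space and makes the expansion unnecessary. Condensed, the pattern (shared with UnsafeExternalSorter below) is:

    // Paraphrase of growPointerArrayIfNecessary() after this patch.
    if (!inMemSorter.hasSpaceForAnotherRecord()) {
      long used = inMemSorter.getMemoryUsage();
      LongArray array;
      try {
        array = allocateArray(used / 8 * 2);   // may spill other consumers, or this one
      } catch (OutOfMemoryError e) {
        // our own spill must have made room
        assert(inMemSorter.hasSpaceForAnotherRecord());
        return;
      }
      if (inMemSorter.hasSpaceForAnotherRecord()) {
        freeArray(array);                      // spilling already freed space
      } else {
        inMemSorter.expandPointerArray(array); // copies, then frees the old array
      }
    }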
@ -404,9 +397,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
// Do not count the final file towards the spill count.
writeSortedFile(true);
freeMemory();
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(sorterMemoryUsage);
}
return spills.toArray(new SpillInfo[spills.size()]);
} catch (IOException e) {

File: ShuffleInMemorySorter.java

@ -19,11 +19,14 @@ package org.apache.spark.shuffle.sort;
import java.util.Comparator;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
final class ShuffleInMemorySorter {
private final Sorter<PackedRecordPointer, long[]> sorter;
private final Sorter<PackedRecordPointer, LongArray> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@Override
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@ -32,24 +35,34 @@ final class ShuffleInMemorySorter {
}
private static final SortComparator SORT_COMPARATOR = new SortComparator();
private final MemoryConsumer consumer;
/**
* An array of record pointers and partition ids that have been encoded by
* {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
* records.
*/
private long[] array;
private LongArray array;
/**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;
public ShuffleInMemorySorter(int initialSize) {
public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
this.consumer = consumer;
assert (initialSize > 0);
this.array = new long[initialSize];
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
public void free() {
if (array != null) {
consumer.freeArray(array);
array = null;
}
}
public int numRecords() {
return pos;
}
@ -58,30 +71,25 @@ final class ShuffleInMemorySorter {
pos = 0;
}
private int newLength() {
// Guard against overflow:
return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
}
/**
* Returns the memory needed to expand
*/
public long getMemoryToExpand() {
return ((long) (newLength() - array.length)) * 8;
}
public void expandPointerArray() {
final long[] oldArray = array;
array = new long[newLength()];
System.arraycopy(oldArray, 0, array, 0, oldArray.length);
public void expandPointerArray(LongArray newArray) {
assert(newArray.size() > array.size());
Platform.copyMemory(
array.getBaseObject(),
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
array.size() * 8L
);
consumer.freeArray(array);
array = newArray;
}
public boolean hasSpaceForAnotherRecord() {
return pos < array.length;
return pos < array.size();
}
public long getMemoryUsage() {
return array.length * 8L;
return array.size() * 8L;
}
/**
@ -96,14 +104,9 @@ final class ShuffleInMemorySorter {
*/
public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
if (array.length == Integer.MAX_VALUE) {
throw new IllegalStateException("Sort pointer array has reached maximum size");
} else {
expandPointerArray();
}
expandPointerArray(consumer.allocateArray(array.size() * 2));
}
array[pos] =
PackedRecordPointer.packPointer(recordPointer, partitionId);
array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
pos++;
}
@ -112,12 +115,12 @@ final class ShuffleInMemorySorter {
*/
public static final class ShuffleSorterIterator {
private final long[] pointerArray;
private final LongArray pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
public ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
@ -127,7 +130,7 @@ final class ShuffleInMemorySorter {
}
public void loadNext() {
packedRecordPointer.set(pointerArray[position]);
packedRecordPointer.set(pointerArray.get(position));
position++;
}
}

File: ShuffleSortDataFormat.java

@ -17,16 +17,19 @@
package org.apache.spark.shuffle.sort;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
private ShuffleSortDataFormat() { }
@Override
public PackedRecordPointer getKey(long[] data, int pos) {
public PackedRecordPointer getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
throw new UnsupportedOperationException();
}
@ -37,31 +40,38 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, lo
}
@Override
public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
reuse.set(data[pos]);
public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) {
reuse.set(data.get(pos));
return reuse;
}
@Override
public void swap(long[] data, int pos0, int pos1) {
final long temp = data[pos0];
data[pos0] = data[pos1];
data[pos1] = temp;
public void swap(LongArray data, int pos0, int pos1) {
final long temp = data.get(pos0);
data.set(pos0, data.get(pos1));
data.set(pos1, temp);
}
@Override
public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
dst[dstPos] = src[srcPos];
public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
dst.set(dstPos, src.get(srcPos));
}
@Override
public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
System.arraycopy(src, srcPos, dst, dstPos, length);
public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
Platform.copyMemory(
src.getBaseObject(),
src.getBaseOffset() + srcPos * 8,
dst.getBaseObject(),
dst.getBaseOffset() + dstPos * 8,
length * 8
);
}
@Override
public long[] allocate(int length) {
return new long[length];
public LongArray allocate(int length) {
// This buffer is used temporarily (and is usually small), so it's fine to allocate it from the JVM heap.
return new LongArray(MemoryBlock.fromLongArray(new long[length]));
}
}

File: BytesToBytesMap.java

@ -20,7 +20,6 @@ package org.apache.spark.unsafe.map;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
@ -724,11 +723,10 @@ public final class BytesToBytesMap extends MemoryConsumer {
*/
private void allocate(int capacity) {
assert (capacity >= 0);
// The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
acquireMemory(capacity * 16);
longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
longArray = allocateArray(capacity * 2);
longArray.zeroOut();
this.growthThreshold = (int) (capacity * loadFactor);
this.mask = capacity - 1;
@ -743,9 +741,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
public void free() {
updatePeakMemoryUsed();
if (longArray != null) {
long used = longArray.memoryBlock().size();
freeArray(longArray);
longArray = null;
releaseMemory(used);
}
Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
@ -834,9 +831,9 @@ public final class BytesToBytesMap extends MemoryConsumer {
/**
* Returns the underlying LongArray.
*/
public long[] getArray() {
public LongArray getArray() {
assert(longArray != null);
return (long[]) longArray.memoryBlock().getBaseObject();
return longArray;
}
/**
@ -844,7 +841,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
*/
public void reset() {
numElements = 0;
Arrays.fill(getArray(), 0);
longArray.zeroOut();
while (dataPages.size() > 0) {
MemoryBlock dataPage = dataPages.removeLast();
freePage(dataPage);
@ -887,7 +885,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
longArray.set(newPos * 2, keyPointer);
longArray.set(newPos * 2 + 1, hashcode);
}
releaseMemory(oldLongArray.memoryBlock().size());
freeArray(oldLongArray);
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
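
For context on the capacity * 2 sizing: each map entry occupies two adjacent slots of longArray, the encoded key pointer at 2 * pos and the key's hashcode at 2 * pos + 1. A simplified probe sketch (illustrative; the real lookup also does open-addressing collision handling):

    // Reading one slot of the hash table; pos has been masked into [0, capacity).
    long keyPointer = longArray.get(pos * 2);
    if (keyPointer != 0 && (int) longArray.get(pos * 2 + 1) == hashcode) {
      // decode page and offset from keyPointer, then compare the full key bytes
    }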

File: UnsafeExternalSorter.java

@ -32,6 +32,7 @@ import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.Utils;
@ -123,9 +124,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
this.writeMetrics = new ShuffleWriteMetrics();
if (existingInMemorySorter == null) {
this.inMemSorter =
new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
acquireMemory(inMemSorter.getMemoryUsage());
this.inMemSorter = new UnsafeInMemorySorter(
this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
} else {
this.inMemSorter = existingInMemorySorter;
}
@ -277,9 +277,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
deleteSpillFiles();
freeMemory();
if (inMemSorter != null) {
long used = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(used);
}
}
}
@ -293,9 +292,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
long needed = used + inMemSorter.getMemoryToExpand();
LongArray array;
try {
acquireMemory(needed); // could trigger spilling
// could trigger spilling
array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have triggered spilling
assert(inMemSorter.hasSpaceForAnotherRecord());
@ -303,16 +303,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
releaseMemory(needed);
freeArray(array);
} else {
try {
inMemSorter.expandPointerArray();
releaseMemory(used);
} catch (OutOfMemoryError oom) {
// Just in case that JVM had run out of memory
releaseMemory(needed);
spill();
}
inMemSorter.expandPointerArray(array);
}
}
}
@ -498,9 +491,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
nextUpstream = null;
assert(inMemSorter != null);
long used = inMemSorter.getMemoryUsage();
inMemSorter.free();
inMemSorter = null;
releaseMemory(used);
}
numRecords--;
upstream.loadNext();

File: UnsafeInMemorySorter.java

@ -19,8 +19,10 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
/**
@ -62,15 +64,16 @@ public final class UnsafeInMemorySorter {
}
}
private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;
private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
* Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
private long[] array;
private LongArray array;
/**
* The position in the sort buffer where new records can be inserted.
@ -78,22 +81,33 @@ public final class UnsafeInMemorySorter {
private int pos = 0;
public UnsafeInMemorySorter(
final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
int initialSize) {
this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
this(consumer, memoryManager, recordComparator, prefixComparator,
consumer.allocateArray(initialSize * 2));
}
public UnsafeInMemorySorter(
final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
long[] array) {
this.array = array;
LongArray array) {
this.consumer = consumer;
this.memoryManager = memoryManager;
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
this.array = array;
}
/**
* Free the memory used by the pointer array.
*/
public void free() {
consumer.freeArray(array);
}
public void reset() {
@ -107,26 +121,26 @@ public final class UnsafeInMemorySorter {
return pos / 2;
}
private int newLength() {
return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
}
public long getMemoryToExpand() {
return (long) (newLength() - array.length) * 8L;
}
public long getMemoryUsage() {
return array.length * 8L;
return array.size() * 8L;
}
public boolean hasSpaceForAnotherRecord() {
return pos + 2 <= array.length;
return pos + 2 <= array.size();
}
public void expandPointerArray() {
final long[] oldArray = array;
array = new long[newLength()];
System.arraycopy(oldArray, 0, array, 0, oldArray.length);
public void expandPointerArray(LongArray newArray) {
if (newArray.size() < array.size()) {
throw new OutOfMemoryError("Not enough memory to grow pointer array");
}
Platform.copyMemory(
array.getBaseObject(),
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
array.size() * 8L);
consumer.freeArray(array);
array = newArray;
}
/**
@ -138,11 +152,11 @@ public final class UnsafeInMemorySorter {
*/
public void insertRecord(long recordPointer, long keyPrefix) {
if (!hasSpaceForAnotherRecord()) {
expandPointerArray();
expandPointerArray(consumer.allocateArray(array.size() * 2));
}
array[pos] = recordPointer;
array.set(pos, recordPointer);
pos++;
array[pos] = keyPrefix;
array.set(pos, keyPrefix);
pos++;
}
@ -150,7 +164,7 @@ public final class UnsafeInMemorySorter {
private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
private final long[] sortBuffer;
private final LongArray sortBuffer;
private int position = 0;
private Object baseObject;
private long baseOffset;
@ -160,7 +174,7 @@ public final class UnsafeInMemorySorter {
private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
long[] sortBuffer) {
LongArray sortBuffer) {
this.memoryManager = memoryManager;
this.sortBufferInsertPosition = sortBufferInsertPosition;
this.sortBuffer = sortBuffer;
@ -188,11 +202,11 @@ public final class UnsafeInMemorySorter {
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = sortBuffer[position];
final long recordPointer = sortBuffer.get(position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
keyPrefix = sortBuffer[position + 1];
keyPrefix = sortBuffer.get(position + 1);
position += 2;
}

File: UnsafeSortDataFormat.java

@ -17,6 +17,9 @@
package org.apache.spark.util.collection.unsafe.sort;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;
/**
@ -26,14 +29,14 @@ import org.apache.spark.util.collection.SortDataFormat;
* Within each LongArray buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
private UnsafeSortDataFormat() { }
@Override
public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
throw new UnsupportedOperationException();
}
@ -44,37 +47,43 @@ final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefi
}
@Override
public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
reuse.recordPointer = data[pos * 2];
reuse.keyPrefix = data[pos * 2 + 1];
public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) {
reuse.recordPointer = data.get(pos * 2);
reuse.keyPrefix = data.get(pos * 2 + 1);
return reuse;
}
@Override
public void swap(long[] data, int pos0, int pos1) {
long tempPointer = data[pos0 * 2];
long tempKeyPrefix = data[pos0 * 2 + 1];
data[pos0 * 2] = data[pos1 * 2];
data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
data[pos1 * 2] = tempPointer;
data[pos1 * 2 + 1] = tempKeyPrefix;
public void swap(LongArray data, int pos0, int pos1) {
long tempPointer = data.get(pos0 * 2);
long tempKeyPrefix = data.get(pos0 * 2 + 1);
data.set(pos0 * 2, data.get(pos1 * 2));
data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
data.set(pos1 * 2, tempPointer);
data.set(pos1 * 2 + 1, tempKeyPrefix);
}
@Override
public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
dst[dstPos * 2] = src[srcPos * 2];
dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
dst.set(dstPos * 2, src.get(srcPos * 2));
dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
}
@Override
public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
Platform.copyMemory(
src.getBaseObject(),
src.getBaseOffset() + srcPos * 16,
dst.getBaseObject(),
dst.getBaseOffset() + dstPos * 16,
length * 16);
}
@Override
public long[] allocate(int length) {
public LongArray allocate(int length) {
assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
return new long[length * 2];
// This is used as temporary buffer, it's fine to allocate from JVM heap.
return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
}
}
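
The byte arithmetic in copyRange follows from the layout above: one logical element is two 8-byte longs, so logical position pos starts at byte offset pos * 16, and copying length elements moves length * 16 bytes. An illustrative helper (not in the patch):

    // Byte offset of logical element `pos` inside the backing LongArray.
    static long byteOffsetOf(long baseOffset, int pos) {
      return baseOffset + pos * 2L * 8;  // = baseOffset + pos * 16
    }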

File: TaskMemoryManagerSuite.java

@ -17,8 +17,6 @@
package org.apache.spark.memory;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Test;
@ -27,27 +25,6 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
public class TaskMemoryManagerSuite {
class TestMemoryConsumer extends MemoryConsumer {
TestMemoryConsumer(TaskMemoryManager memoryManager) {
super(memoryManager);
}
@Override
public long spill(long size, MemoryConsumer trigger) throws IOException {
long used = getUsed();
releaseMemory(used);
return used;
}
void use(long size) {
acquireMemory(size);
}
void free(long size) {
releaseMemory(size);
}
}
@Test
public void leakedPageMemoryIsDetected() {
final TaskMemoryManager manager = new TaskMemoryManager(

File: TestMemoryConsumer.java (new file)

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.memory;
import java.io.IOException;
public class TestMemoryConsumer extends MemoryConsumer {
public TestMemoryConsumer(TaskMemoryManager memoryManager) {
super(memoryManager);
}
@Override
public long spill(long size, MemoryConsumer trigger) throws IOException {
long used = getUsed();
free(used);
return used;
}
void use(long size) {
long got = taskMemoryManager.acquireExecutionMemory(size, this);
used += got;
}
void free(long size) {
used -= size;
taskMemoryManager.releaseExecutionMemory(size, this);
}
}
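
A sketch of how a test can drive cooperative spilling with this helper; the setup mirrors TaskMemoryManagerSuite, and the 1000-byte cap via TestMemoryManager.limit() is an assumption of this sketch:

    TestMemoryManager memoryManager =
        new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
    TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
    memoryManager.limit(1000);    // assumed cap API
    TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
    TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
    c1.use(100);
    c2.use(1000);                 // not enough memory left: forces c1.spill()
    Assert.assertEquals(0, c1.getUsed());
    Assert.assertEquals(1000, c2.getUsed());
    c2.free(1000);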

File: ShuffleInMemorySorterSuite.java

@ -25,13 +25,19 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryConsumer;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
public class ShuffleInMemorySorterSuite {
final TestMemoryManager memoryManager =
new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
final byte[] strBytes = new byte[strLength];
Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
@ -40,7 +46,7 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100);
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
assert(!iter.hasNext());
}
@ -63,7 +69,7 @@ public class ShuffleInMemorySorterSuite {
new TaskMemoryManager(new TestMemoryManager(conf), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
@ -104,7 +110,7 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {

File: UnsafeExternalSorterSuite.java

@ -390,7 +390,6 @@ public class UnsafeExternalSorterSuite {
for (int i = 0; i < numRecordsPerPage * 10; i++) {
insertNumber(sorter, i);
newPeakMemory = sorter.getPeakMemoryUsedBytes();
// The first page is pre-allocated on instantiation
if (i % numRecordsPerPage == 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);

File: UnsafeInMemorySorterSuite.java

@ -23,6 +23,7 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.memory.TestMemoryConsumer;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@ -44,9 +45,11 @@ public class UnsafeInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
final TaskMemoryManager memoryManager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
memoryManager,
mock(RecordComparator.class),
mock(PrefixComparator.class),
100);
@ -69,6 +72,7 @@ public class UnsafeInMemorySorterSuite {
};
final TaskMemoryManager memoryManager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
// Write the records into the data page:
@ -102,7 +106,7 @@ public class UnsafeInMemorySorterSuite {
return (int) prefix1 - (int) prefix2;
}
};
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator,
prefixComparator, dataToSort.length);
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();

File: UnsafeKVExternalSorter.java

@ -85,8 +85,9 @@ public final class UnsafeKVExternalSorter {
} else {
// During spilling, the array in map will not be used, so we can borrow that and use it
// as the underlying array for the in-memory sorter (it's always large enough).
// Since we will not grow the array, it's fine to pass `null` as consumer.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
taskMemoryManager, recordComparator, prefixComparator, map.getArray());
null, taskMemoryManager, recordComparator, prefixComparator, map.getArray());
// We cannot use the destructive iterator here because we are reusing the existing memory
// pages in BytesToBytesMap to hold records during sorting.

File: LongArray.java

@ -39,7 +39,6 @@ public final class LongArray {
private final long length;
public LongArray(MemoryBlock memory) {
assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")";
assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements";
this.memory = memory;
this.baseObj = memory.getBaseObject();
@ -51,6 +50,14 @@ public final class LongArray {
return memory;
}
public Object getBaseObject() {
return baseObj;
}
public long getBaseOffset() {
return baseOffset;
}
/**
* Returns the number of elements this array can hold.
*/
@ -58,6 +65,15 @@ public final class LongArray {
return length;
}
/**
* Fill this array with 0L.
*/
public void zeroOut() {
for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) {
Platform.putLong(baseObj, off, 0);
}
}
/**
* Sets the value at position {@code index}.
*/
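
With getBaseObject/getBaseOffset and zeroOut, callers can treat on-heap and off-heap arrays uniformly. A minimal off-heap sketch (the UnsafeMemoryAllocator usage is illustrative; the sorters get their MemoryBlock from TaskMemoryManager instead):

    import org.apache.spark.unsafe.array.LongArray;
    import org.apache.spark.unsafe.memory.MemoryBlock;
    import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator;

    UnsafeMemoryAllocator allocator = new UnsafeMemoryAllocator();
    MemoryBlock block = allocator.allocate(16 * 8);  // room for 16 longs, off-heap
    LongArray array = new LongArray(block);
    array.zeroOut();                                 // added in this patch
    array.set(3, 7L);
    assert array.get(3) == 7L;
    assert array.getBaseObject() == null;            // off-heap: no backing Java object
    allocator.free(block);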

File: LongArraySuite.java

@ -34,5 +34,9 @@ public class LongArraySuite {
Assert.assertEquals(2, arr.size());
Assert.assertEquals(1L, arr.get(0));
Assert.assertEquals(3L, arr.get(1));
arr.zeroOut();
Assert.assertEquals(0L, arr.get(0));
Assert.assertEquals(0L, arr.get(1));
}
}