[SPARK-32900][CORE] Allow UnsafeExternalSorter to spill when there are nulls

### What changes were proposed in this pull request?

This PR changes the way `UnsafeExternalSorter.SpillableIterator` checks whether it has spilled already, by checking whether `inMemSorter` is null. It also allows it to spill other `UnsafeSorterIterator`s than `UnsafeInMemorySorter.SortedIterator`.

### Why are the changes needed?

Before this PR `UnsafeExternalSorter.SpillableIterator` could not spill when there are NULLs in the input and radix sorting is used. Currently, Spark determines whether UnsafeExternalSorter.SpillableIterator has not spilled yet by checking whether `upstream` is an instance of `UnsafeInMemorySorter.SortedIterator`. When radix sorting is used and there are NULLs in the input however, `upstream` will be an instance of `UnsafeExternalSorter.ChainedIterator` instead, and Spark will assume that the `SpillableIterator` iterator has spilled already, and therefore cannot spill again when it's supposed to spill.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A test was added to `UnsafeExternalSorterSuite` (and therefore also to `UnsafeExternalSorterRadixSortSuite`). I manually confirmed that the test failed in `UnsafeExternalSorterRadixSortSuite` without this patch.

Closes #29772 from tomvanbussel/SPARK-32900.

Authored-by: Tom van Bussel <tom.vanbussel@databricks.com>
Signed-off-by: herman <herman@databricks.com>
This commit is contained in:
Tom van Bussel 2020-09-17 12:35:40 +02:00 committed by herman
parent 92b75dc260
commit e5e54a3614
6 changed files with 88 additions and 27 deletions

View file

@ -501,11 +501,15 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
*/ */
class SpillableIterator extends UnsafeSorterIterator { class SpillableIterator extends UnsafeSorterIterator {
private UnsafeSorterIterator upstream; private UnsafeSorterIterator upstream;
private UnsafeSorterIterator nextUpstream = null;
private MemoryBlock lastPage = null; private MemoryBlock lastPage = null;
private boolean loaded = false; private boolean loaded = false;
private int numRecords = 0; private int numRecords = 0;
private Object currentBaseObject;
private long currentBaseOffset;
private int currentRecordLength;
private long currentKeyPrefix;
SpillableIterator(UnsafeSorterIterator inMemIterator) { SpillableIterator(UnsafeSorterIterator inMemIterator) {
this.upstream = inMemIterator; this.upstream = inMemIterator;
this.numRecords = inMemIterator.getNumRecords(); this.numRecords = inMemIterator.getNumRecords();
@ -516,23 +520,26 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
return numRecords; return numRecords;
} }
@Override
public long getCurrentPageNumber() {
throw new UnsupportedOperationException();
}
public long spill() throws IOException { public long spill() throws IOException {
synchronized (this) { synchronized (this) {
if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null if (inMemSorter == null || numRecords <= 0) {
&& numRecords > 0)) {
return 0L; return 0L;
} }
UnsafeInMemorySorter.SortedIterator inMemIterator = long currentPageNumber = upstream.getCurrentPageNumber();
((UnsafeInMemorySorter.SortedIterator) upstream).clone();
ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
// Iterate over the records that have not been returned and spill them. // Iterate over the records that have not been returned and spill them.
final UnsafeSorterSpillWriter spillWriter = final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
spillIterator(inMemIterator, spillWriter); spillIterator(upstream, spillWriter);
spillWriters.add(spillWriter); spillWriters.add(spillWriter);
nextUpstream = spillWriter.getReader(serializerManager); upstream = spillWriter.getReader(serializerManager);
long released = 0L; long released = 0L;
synchronized (UnsafeExternalSorter.this) { synchronized (UnsafeExternalSorter.this) {
@ -540,8 +547,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
// is accessing the current record. We free this page in that caller's next loadNext() // is accessing the current record. We free this page in that caller's next loadNext()
// call. // call.
for (MemoryBlock page : allocatedPages) { for (MemoryBlock page : allocatedPages) {
if (!loaded || page.pageNumber != if (!loaded || page.pageNumber != currentPageNumber) {
((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
released += page.size(); released += page.size();
freePage(page); freePage(page);
} else { } else {
@ -575,22 +581,26 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
try { try {
synchronized (this) { synchronized (this) {
loaded = true; loaded = true;
if (nextUpstream != null) { // Just consumed the last record from in memory iterator
// Just consumed the last record from in memory iterator if (lastPage != null) {
if(lastPage != null) { // Do not free the page here, while we are locking `SpillableIterator`. The `freePage`
// Do not free the page here, while we are locking `SpillableIterator`. The `freePage` // method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in
// method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in // sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and
// sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and // `SpillableIterator` in sequence, which may happen in
// `SpillableIterator` in sequence, which may happen in // `TaskMemoryManager.acquireExecutionMemory`.
// `TaskMemoryManager.acquireExecutionMemory`. pageToFree = lastPage;
pageToFree = lastPage; lastPage = null;
lastPage = null;
}
upstream = nextUpstream;
nextUpstream = null;
} }
numRecords--; numRecords--;
upstream.loadNext(); upstream.loadNext();
// Keep track of the current base object, base offset, record length, and key prefix,
// so that the current record can still be read in case a spill is triggered and we
// switch to the spill writer's iterator.
currentBaseObject = upstream.getBaseObject();
currentBaseOffset = upstream.getBaseOffset();
currentRecordLength = upstream.getRecordLength();
currentKeyPrefix = upstream.getKeyPrefix();
} }
} finally { } finally {
if (pageToFree != null) { if (pageToFree != null) {
@ -601,22 +611,22 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
@Override @Override
public Object getBaseObject() { public Object getBaseObject() {
return upstream.getBaseObject(); return currentBaseObject;
} }
@Override @Override
public long getBaseOffset() { public long getBaseOffset() {
return upstream.getBaseOffset(); return currentBaseOffset;
} }
@Override @Override
public int getRecordLength() { public int getRecordLength() {
return upstream.getRecordLength(); return currentRecordLength;
} }
@Override @Override
public long getKeyPrefix() { public long getKeyPrefix() {
return upstream.getKeyPrefix(); return currentKeyPrefix;
} }
} }
@ -693,6 +703,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
return numRecords; return numRecords;
} }
@Override
public long getCurrentPageNumber() {
return current.getCurrentPageNumber();
}
@Override @Override
public boolean hasNext() { public boolean hasNext() {
while (!current.hasNext() && !iterators.isEmpty()) { while (!current.hasNext() && !iterators.isEmpty()) {

View file

@ -330,6 +330,7 @@ public final class UnsafeInMemorySorter {
@Override @Override
public long getBaseOffset() { return baseOffset; } public long getBaseOffset() { return baseOffset; }
@Override
public long getCurrentPageNumber() { public long getCurrentPageNumber() {
return currentPageNumber; return currentPageNumber;
} }

View file

@ -34,4 +34,6 @@ public abstract class UnsafeSorterIterator {
public abstract long getKeyPrefix(); public abstract long getKeyPrefix();
public abstract int getNumRecords(); public abstract int getNumRecords();
public abstract long getCurrentPageNumber();
} }

View file

@ -70,6 +70,11 @@ final class UnsafeSorterSpillMerger {
return numRecords; return numRecords;
} }
@Override
public long getCurrentPageNumber() {
throw new UnsupportedOperationException();
}
@Override @Override
public boolean hasNext() { public boolean hasNext() {
return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());

View file

@ -89,6 +89,11 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
return numRecords; return numRecords;
} }
@Override
public long getCurrentPageNumber() {
throw new UnsupportedOperationException();
}
@Override @Override
public boolean hasNext() { public boolean hasNext() {
return (numRecordsRemaining > 0); return (numRecordsRemaining > 0);

View file

@ -359,6 +359,39 @@ public class UnsafeExternalSorterSuite {
assertSpillFilesWereCleanedUp(); assertSpillFilesWereCleanedUp();
} }
@Test
public void forcedSpillingNullsWithReadIterator() throws Exception {
final UnsafeExternalSorter sorter = newSorter();
long[] record = new long[100];
final int recordSize = record.length * 8;
final int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
boolean isNull = i % 2 == 0;
sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, isNull);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
(UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
final int numRecordsToReadBeforeSpilling = n / 3;
for (int i = 0; i < numRecordsToReadBeforeSpilling; i++) {
assertTrue(iter.hasNext());
iter.loadNext();
}
assertTrue(iter.spill() > 0);
assertEquals(0, iter.spill());
for (int i = numRecordsToReadBeforeSpilling; i < n; i++) {
assertTrue(iter.hasNext());
iter.loadNext();
}
assertFalse(iter.hasNext());
sorter.cleanupResources();
assertSpillFilesWereCleanedUp();
}
@Test @Test
public void forcedSpillingWithNotReadIterator() throws Exception { public void forcedSpillingWithNotReadIterator() throws Exception {
final UnsafeExternalSorter sorter = newSorter(); final UnsafeExternalSorter sorter = newSorter();