[SPARK-9577][SQL] Surface concrete iterator types in various sort classes.

We often return abstract iterator types in various sort-related classes (e.g. UnsafeKVExternalSorter). It is actually better to return a more concrete type, so the callsite uses that type and JIT can inline the iterator calls.

Author: Reynold Xin <rxin@databricks.com>

Closes #7911 from rxin/surface-concrete-type and squashes the following commits:

0422add [Reynold Xin] [SPARK-9577][SQL] Surface concrete iterator types in various sort classes.
This commit is contained in:
Reynold Xin 2015-08-03 18:47:02 -07:00
parent 3b0e44490a
commit 5eb89f67e3
4 changed files with 65 additions and 85 deletions

View file

@ -428,7 +428,7 @@ public final class UnsafeExternalSorter {
public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(inMemSorter != null);
final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
return inMemoryIterator;

View file

@ -133,7 +133,7 @@ public final class UnsafeInMemorySorter {
pointerArrayInsertPosition++;
}
private static final class SortedIterator extends UnsafeSorterIterator {
public static final class SortedIterator extends UnsafeSorterIterator {
private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
@ -144,7 +144,7 @@ public final class UnsafeInMemorySorter {
private long keyPrefix;
private int recordLength;
SortedIterator(
private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
long[] sortBuffer) {
@ -186,7 +186,7 @@ public final class UnsafeInMemorySorter {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
public UnsafeSorterIterator getSortedIterator() {
public SortedIterator getSortedIterator() {
sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
}

View file

@ -134,7 +134,7 @@ public final class UnsafeKVExternalSorter {
value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
}
public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
public KVSorterIterator sortedIterator() throws IOException {
try {
final UnsafeSorterIterator underlying = sorter.getSortedIterator();
if (!underlying.hasNext()) {
@ -142,58 +142,7 @@ public final class UnsafeKVExternalSorter {
// here in order to prevent memory leaks.
cleanupResources();
}
return new KVIterator<UnsafeRow, UnsafeRow>() {
private UnsafeRow key = new UnsafeRow();
private UnsafeRow value = new UnsafeRow();
private int numKeyFields = keySchema.size();
private int numValueFields = valueSchema.size();
@Override
public boolean next() throws IOException {
try {
if (underlying.hasNext()) {
underlying.loadNext();
Object baseObj = underlying.getBaseObject();
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
return true;
} else {
key = null;
value = null;
cleanupResources();
return false;
}
} catch (IOException e) {
cleanupResources();
throw e;
}
}
@Override
public UnsafeRow getKey() {
return key;
}
@Override
public UnsafeRow getValue() {
return value;
}
@Override
public void close() {
cleanupResources();
}
};
return new KVSorterIterator(underlying);
} catch (IOException e) {
cleanupResources();
throw e;
@ -233,4 +182,61 @@ public final class UnsafeKVExternalSorter {
return ordering.compare(row1, row2);
}
}
public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
private UnsafeRow key = new UnsafeRow();
private UnsafeRow value = new UnsafeRow();
private final int numKeyFields = keySchema.size();
private final int numValueFields = valueSchema.size();
private final UnsafeSorterIterator underlying;
private KVSorterIterator(UnsafeSorterIterator underlying) {
this.underlying = underlying;
}
@Override
public boolean next() throws IOException {
try {
if (underlying.hasNext()) {
underlying.loadNext();
Object baseObj = underlying.getBaseObject();
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
return true;
} else {
key = null;
value = null;
cleanupResources();
return false;
}
} catch (IOException e) {
cleanupResources();
throw e;
}
}
@Override
public UnsafeRow getKey() {
return key;
}
@Override
public UnsafeRow getValue() {
return value;
}
@Override
public void close() {
cleanupResources();
}
};
}

View file

@ -17,12 +17,12 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.{UnsafeKeyValueSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.types.StructType
/**
@ -230,7 +230,7 @@ class UnsafeHybridAggregationIterator(
}
// Step 5: Get the sorted iterator from the externalSorter.
val sortedKVIterator: KVIterator[UnsafeRow, UnsafeRow] = externalSorter.sortedIterator()
val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
// Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
// For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
@ -368,31 +368,5 @@ object UnsafeHybridAggregationIterator {
newMutableProjection,
outputsUnsafeRows)
}
def createFromKVIterator(
groupingKeyAttributes: Seq[Attribute],
valueAttributes: Seq[Attribute],
inputKVIterator: KVIterator[UnsafeRow, InternalRow],
nonCompleteAggregateExpressions: Seq[AggregateExpression2],
nonCompleteAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
new UnsafeHybridAggregationIterator(
groupingKeyAttributes,
valueAttributes,
inputKVIterator,
nonCompleteAggregateExpressions,
nonCompleteAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
initialInputBufferOffset,
resultExpressions,
newMutableProjection,
outputsUnsafeRows)
}
// scalastyle:on
}