[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate

This PR improve the lookup of BytesToBytesMap by:

1. Generate code for calculate the hash code of grouping keys.

2. Do not use MemoryLocation, fetch the baseObject and offset for key and value directly (remove the indirection).

Author: Davies Liu <davies@databricks.com>

Closes #11010 from davies/gen_map.
This commit is contained in:
Davies Liu 2016-02-09 16:41:21 -08:00 committed by Davies Liu
parent fae830d158
commit 0e5ebac3c1
10 changed files with 184 additions and 129 deletions

View file

@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
private final TaskMemoryManager taskMemoryManager; private final TaskMemoryManager taskMemoryManager;
@ -417,7 +414,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
* This function always return the same {@link Location} instance to avoid object allocation. * This function always return the same {@link Location} instance to avoid object allocation.
*/ */
public Location lookup(Object keyBase, long keyOffset, int keyLength) { public Location lookup(Object keyBase, long keyOffset, int keyLength) {
safeLookup(keyBase, keyOffset, keyLength, loc); safeLookup(keyBase, keyOffset, keyLength, loc,
Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
return loc;
}
/**
* Looks up a key, and return a {@link Location} handle that can be used to test existence
* and read/write values.
*
* This function always return the same {@link Location} instance to avoid object allocation.
*/
public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
safeLookup(keyBase, keyOffset, keyLength, loc, hash);
return loc; return loc;
} }
@ -426,14 +435,13 @@ public final class BytesToBytesMap extends MemoryConsumer {
* *
* This is a thread-safe version of `lookup`, could be used by multiple threads. * This is a thread-safe version of `lookup`, could be used by multiple threads.
*/ */
public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
assert(longArray != null); assert(longArray != null);
if (enablePerfMetrics) { if (enablePerfMetrics) {
numKeyLookups++; numKeyLookups++;
} }
final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); int pos = hash & mask;
int pos = hashcode & mask;
int step = 1; int step = 1;
while (true) { while (true) {
if (enablePerfMetrics) { if (enablePerfMetrics) {
@ -441,22 +449,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
} }
if (longArray.get(pos * 2) == 0) { if (longArray.get(pos * 2) == 0) {
// This is a new key. // This is a new key.
loc.with(pos, hashcode, false); loc.with(pos, hash, false);
return; return;
} else { } else {
long stored = longArray.get(pos * 2 + 1); long stored = longArray.get(pos * 2 + 1);
if ((int) (stored) == hashcode) { if ((int) (stored) == hash) {
// Full hash code matches. Let's compare the keys for equality. // Full hash code matches. Let's compare the keys for equality.
loc.with(pos, hashcode, true); loc.with(pos, hash, true);
if (loc.getKeyLength() == keyLength) { if (loc.getKeyLength() == keyLength) {
final MemoryLocation keyAddress = loc.getKeyAddress();
final Object storedkeyBase = keyAddress.getBaseObject();
final long storedkeyOffset = keyAddress.getBaseOffset();
final boolean areEqual = ByteArrayMethods.arrayEquals( final boolean areEqual = ByteArrayMethods.arrayEquals(
keyBase, keyBase,
keyOffset, keyOffset,
storedkeyBase, loc.getKeyBase(),
storedkeyOffset, loc.getKeyOffset(),
keyLength keyLength
); );
if (areEqual) { if (areEqual) {
@ -484,13 +489,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
private boolean isDefined; private boolean isDefined;
/** /**
* The hashcode of the most recent key passed to * The hashcode of the most recent key passed to
* {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to * {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
* avoid re-hashing the key when storing a value for that key. * to avoid re-hashing the key when storing a value for that key.
*/ */
private int keyHashcode; private int keyHashcode;
private final MemoryLocation keyMemoryLocation = new MemoryLocation(); private Object baseObject; // the base object for key and value
private final MemoryLocation valueMemoryLocation = new MemoryLocation(); private long keyOffset;
private int keyLength; private int keyLength;
private long valueOffset;
private int valueLength; private int valueLength;
/** /**
@ -504,18 +510,15 @@ public final class BytesToBytesMap extends MemoryConsumer {
taskMemoryManager.getOffsetInPage(fullKeyAddress)); taskMemoryManager.getOffsetInPage(fullKeyAddress));
} }
private void updateAddressesAndSizes(final Object base, final long offset) { private void updateAddressesAndSizes(final Object base, long offset) {
long position = offset; baseObject = base;
final int totalLength = Platform.getInt(base, position); final int totalLength = Platform.getInt(base, offset);
position += 4; offset += 4;
keyLength = Platform.getInt(base, position); keyLength = Platform.getInt(base, offset);
position += 4; offset += 4;
keyOffset = offset;
valueOffset = offset + keyLength;
valueLength = totalLength - keyLength - 4; valueLength = totalLength - keyLength - 4;
keyMemoryLocation.setObjAndOffset(base, position);
position += keyLength;
valueMemoryLocation.setObjAndOffset(base, position);
} }
private Location with(int pos, int keyHashcode, boolean isDefined) { private Location with(int pos, int keyHashcode, boolean isDefined) {
@ -543,10 +546,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
private Location with(Object base, long offset, int length) { private Location with(Object base, long offset, int length) {
this.isDefined = true; this.isDefined = true;
this.memoryPage = null; this.memoryPage = null;
baseObject = base;
keyOffset = offset + 4;
keyLength = Platform.getInt(base, offset); keyLength = Platform.getInt(base, offset);
valueOffset = offset + 4 + keyLength;
valueLength = length - 4 - keyLength; valueLength = length - 4 - keyLength;
keyMemoryLocation.setObjAndOffset(base, offset + 4);
valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
return this; return this;
} }
@ -566,14 +570,35 @@ public final class BytesToBytesMap extends MemoryConsumer {
} }
/** /**
* Returns the address of the key defined at this position. * Returns the base object for key.
* This points to the first byte of the key data.
* Unspecified behavior if the key is not defined.
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
*/ */
public MemoryLocation getKeyAddress() { public Object getKeyBase() {
assert (isDefined); assert (isDefined);
return keyMemoryLocation; return baseObject;
}
/**
* Returns the offset for key.
*/
public long getKeyOffset() {
assert (isDefined);
return keyOffset;
}
/**
* Returns the base object for value.
*/
public Object getValueBase() {
assert (isDefined);
return baseObject;
}
/**
* Returns the offset for value.
*/
public long getValueOffset() {
assert (isDefined);
return valueOffset;
} }
/** /**
@ -585,17 +610,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
return keyLength; return keyLength;
} }
/**
* Returns the address of the value defined at this position.
* This points to the first byte of the value data.
* Unspecified behavior if the key is not defined.
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
*/
public MemoryLocation getValueAddress() {
assert (isDefined);
return valueMemoryLocation;
}
/** /**
* Returns the length of the value defined at this position. * Returns the length of the value defined at this position.
* Unspecified behavior if the key is not defined. * Unspecified behavior if the key is not defined.

View file

@ -39,14 +39,13 @@ import org.mockito.stubbing.Answer;
import org.apache.spark.SparkConf; import org.apache.spark.SparkConf;
import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*; import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.util.Utils; import org.apache.spark.util.Utils;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
@ -142,10 +141,9 @@ public abstract class AbstractBytesToBytesMapSuite {
protected abstract boolean useOffHeapMemoryAllocator(); protected abstract boolean useOffHeapMemoryAllocator();
private static byte[] getByteArray(MemoryLocation loc, int size) { private static byte[] getByteArray(Object base, long offset, int size) {
final byte[] arr = new byte[size]; final byte[] arr = new byte[size];
Platform.copyMemory( Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
return arr; return arr;
} }
@ -163,13 +161,14 @@ public abstract class AbstractBytesToBytesMapSuite {
*/ */
private static boolean arrayEquals( private static boolean arrayEquals(
byte[] expected, byte[] expected,
MemoryLocation actualAddr, Object base,
long offset,
long actualLengthBytes) { long actualLengthBytes) {
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals( return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
expected, expected,
Platform.BYTE_ARRAY_OFFSET, Platform.BYTE_ARRAY_OFFSET,
actualAddr.getBaseObject(), base,
actualAddr.getBaseOffset(), offset,
expected.length expected.length
); );
} }
@ -212,16 +211,20 @@ public abstract class AbstractBytesToBytesMapSuite {
// reflect the result of this store without us having to call lookup() again on the same key. // reflect the result of this store without us having to call lookup() again on the same key.
Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength());
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); Assert.assertArrayEquals(keyData,
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
Assert.assertArrayEquals(valueData,
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
// After calling lookup() the location should still point to the correct data. // After calling lookup() the location should still point to the correct data.
Assert.assertTrue( Assert.assertTrue(
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
Assert.assertEquals(recordLengthBytes, loc.getValueLength()); Assert.assertEquals(recordLengthBytes, loc.getValueLength());
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); Assert.assertArrayEquals(keyData,
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
Assert.assertArrayEquals(valueData,
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
try { try {
Assert.assertTrue(loc.putNewKey( Assert.assertTrue(loc.putNewKey(
@ -283,15 +286,12 @@ public abstract class AbstractBytesToBytesMapSuite {
while (iter.hasNext()) { while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next(); final BytesToBytesMap.Location loc = iter.next();
Assert.assertTrue(loc.isDefined()); Assert.assertTrue(loc.isDefined());
final MemoryLocation keyAddress = loc.getKeyAddress(); final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
final MemoryLocation valueAddress = loc.getValueAddress();
final long value = Platform.getLong(
valueAddress.getBaseObject(), valueAddress.getBaseOffset());
final long keyLength = loc.getKeyLength(); final long keyLength = loc.getKeyLength();
if (keyLength == 0) { if (keyLength == 0) {
Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
} else { } else {
final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset()); final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
Assert.assertEquals(value, key); Assert.assertEquals(value, key);
} }
valuesSeen.set((int) value); valuesSeen.set((int) value);
@ -365,15 +365,15 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
Platform.copyMemory( Platform.copyMemory(
loc.getKeyAddress().getBaseObject(), loc.getKeyBase(),
loc.getKeyAddress().getBaseOffset(), loc.getKeyOffset(),
key, key,
Platform.LONG_ARRAY_OFFSET, Platform.LONG_ARRAY_OFFSET,
KEY_LENGTH KEY_LENGTH
); );
Platform.copyMemory( Platform.copyMemory(
loc.getValueAddress().getBaseObject(), loc.getValueBase(),
loc.getValueAddress().getBaseOffset(), loc.getValueOffset(),
value, value,
Platform.LONG_ARRAY_OFFSET, Platform.LONG_ARRAY_OFFSET,
VALUE_LENGTH VALUE_LENGTH
@ -425,8 +425,9 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertTrue(loc.isDefined()); Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength()); Assert.assertEquals(value.length, loc.getValueLength());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
} }
} }
@ -436,8 +437,10 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc = final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined()); Assert.assertTrue(loc.isDefined());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
} }
} finally { } finally {
map.free(); map.free();
@ -476,8 +479,9 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertTrue(loc.isDefined()); Assert.assertTrue(loc.isDefined());
Assert.assertEquals(key.length, loc.getKeyLength()); Assert.assertEquals(key.length, loc.getKeyLength());
Assert.assertEquals(value.length, loc.getValueLength()); Assert.assertEquals(value.length, loc.getValueLength());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
} }
} }
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) { for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
@ -486,8 +490,10 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc = final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length); map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
Assert.assertTrue(loc.isDefined()); Assert.assertTrue(loc.isDefined());
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); Assert.assertTrue(
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
Assert.assertTrue(
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
} }
} finally { } finally {
map.free(); map.free();

View file

@ -40,6 +40,7 @@ object MimaExcludes {
excludePackage("org.apache.spark.rpc"), excludePackage("org.apache.spark.rpc"),
excludePackage("org.spark-project.jetty"), excludePackage("org.spark-project.jetty"),
excludePackage("org.apache.spark.unused"), excludePackage("org.apache.spark.unused"),
excludePackage("org.apache.spark.unsafe"),
excludePackage("org.apache.spark.util.collection.unsafe"), excludePackage("org.apache.spark.util.collection.unsafe"),
excludePackage("org.apache.spark.sql.catalyst"), excludePackage("org.apache.spark.sql.catalyst"),
excludePackage("org.apache.spark.sql.execution"), excludePackage("org.apache.spark.sql.execution"),

View file

@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
} }
} }
override def genCode(ctx: CodegenContext, ev: ExprCode): String = { override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false" ev.isNull = "false"
val childrenHash = children.map { child => val childrenHash = children.map { child =>

View file

@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow); return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
} }
public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) { public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
return getAggregationBufferFromUnsafeRow(key, key.hashCode());
}
public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
// Probe our map using the serialized key // Probe our map using the serialized key
final BytesToBytesMap.Location loc = map.lookup( final BytesToBytesMap.Location loc = map.lookup(
unsafeGroupingKeyRow.getBaseObject(), key.getBaseObject(),
unsafeGroupingKeyRow.getBaseOffset(), key.getBaseOffset(),
unsafeGroupingKeyRow.getSizeInBytes()); key.getSizeInBytes(),
hash);
if (!loc.isDefined()) { if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the // This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map: // empty aggregation buffer into the map:
boolean putSucceeded = loc.putNewKey( boolean putSucceeded = loc.putNewKey(
unsafeGroupingKeyRow.getBaseObject(), key.getBaseObject(),
unsafeGroupingKeyRow.getBaseOffset(), key.getBaseOffset(),
unsafeGroupingKeyRow.getSizeInBytes(), key.getSizeInBytes(),
emptyAggregationBuffer, emptyAggregationBuffer,
Platform.BYTE_ARRAY_OFFSET, Platform.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length emptyAggregationBuffer.length
@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
} }
// Reset the pointer to point to the value that we just stored or looked up: // Reset the pointer to point to the value that we just stored or looked up:
final MemoryLocation address = loc.getValueAddress();
currentAggregationBuffer.pointTo( currentAggregationBuffer.pointTo(
address.getBaseObject(), loc.getValueBase(),
address.getBaseOffset(), loc.getValueOffset(),
loc.getValueLength() loc.getValueLength()
); );
return currentAggregationBuffer; return currentAggregationBuffer;
@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
public boolean next() { public boolean next() {
if (mapLocationIterator.hasNext()) { if (mapLocationIterator.hasNext()) {
final BytesToBytesMap.Location loc = mapLocationIterator.next(); final BytesToBytesMap.Location loc = mapLocationIterator.next();
final MemoryLocation keyAddress = loc.getKeyAddress();
final MemoryLocation valueAddress = loc.getValueAddress();
key.pointTo( key.pointTo(
keyAddress.getBaseObject(), loc.getKeyBase(),
keyAddress.getBaseOffset(), loc.getKeyOffset(),
loc.getKeyLength() loc.getKeyLength()
); );
value.pointTo( value.pointTo(
valueAddress.getBaseObject(), loc.getValueBase(),
valueAddress.getBaseOffset(), loc.getValueOffset(),
loc.getValueLength() loc.getValueLength()
); );
return true; return true;

View file

@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
UnsafeRow row = new UnsafeRow(numKeyFields); UnsafeRow row = new UnsafeRow(numKeyFields);
while (iter.hasNext()) { while (iter.hasNext()) {
final BytesToBytesMap.Location loc = iter.next(); final BytesToBytesMap.Location loc = iter.next();
final Object baseObject = loc.getKeyAddress().getBaseObject(); final Object baseObject = loc.getKeyBase();
final long baseOffset = loc.getKeyAddress().getBaseOffset(); final long baseOffset = loc.getKeyOffset();
// Get encoded memory address // Get encoded memory address
// baseObject + baseOffset point to the beginning of the key data in the map, but that // baseObject + baseOffset point to the beginning of the key data in the map, but that

View file

@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
def apply(plan: SparkPlan): SparkPlan = { def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) { if (sqlContext.conf.wholeStageEnabled) {
plan.transform { plan.transform {
case plan: CodegenSupport if supportCodegen(plan) && case plan: CodegenSupport if supportCodegen(plan) =>
// Whole stage codegen is only useful when there are at least two levels of operators that
// support it (save at least one projection/iterator).
(Utils.isTesting || plan.children.exists(supportCodegen)) =>
var inputs = ArrayBuffer[SparkPlan]() var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform { val combined = plan.transform {
// The build side can't be compiled together // The build side can't be compiled together

View file

@ -501,6 +501,11 @@ case class TungstenAggregate(
} }
} }
// generate hash code for key
val hashExpr = Murmur3Hash(groupingExpressions, 42)
ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
val inputAttr = bufferAttributes ++ child.output val inputAttr = bufferAttributes ++ child.output
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer ctx.INPUT_ROW = buffer
@ -526,10 +531,11 @@ case class TungstenAggregate(
s""" s"""
// generate grouping key // generate grouping key
${keyCode.code.trim} ${keyCode.code.trim}
${hashEval.code.trim}
UnsafeRow $buffer = null; UnsafeRow $buffer = null;
if ($checkFallback) { if ($checkFallback) {
// try to get the buffer from hash map // try to get the buffer from hash map
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
} }
if ($buffer == null) { if ($buffer == null) {
if ($sorterTerm == null) { if ($sorterTerm == null) {
@ -540,7 +546,7 @@ case class TungstenAggregate(
$resetCoulter $resetCoulter
// the hash map had be spilled, it should have enough memory now, // the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again. // try to allocate buffer again.
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
if ($buffer == null) { if ($buffer == null) {
// failed to allocate the first page // failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation"); throw new OutOfMemoryError("No enough memory for aggregation");

View file

@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
val map = binaryMap // avoid the compiler error val map = binaryMap // avoid the compiler error
val loc = new map.Location // this could be allocated in stack val loc = new map.Location // this could be allocated in stack
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes, loc) unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
if (loc.isDefined) { if (loc.isDefined) {
val buffer = CompactBuffer[UnsafeRow]() val buffer = CompactBuffer[UnsafeRow]()
val base = loc.getValueAddress.getBaseObject val base = loc.getValueBase
var offset = loc.getValueAddress.getBaseOffset var offset = loc.getValueOffset
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength val last = offset + loc.getValueLength
while (offset < last) { while (offset < last) {
val numFields = Platform.getInt(base, offset) val numFields = Platform.getInt(base, offset)
val sizeInBytes = Platform.getInt(base, offset + 4) val sizeInBytes = Platform.getInt(base, offset + 4)
@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
out.writeInt(binaryMap.numElements()) out.writeInt(binaryMap.numElements())
var buffer = new Array[Byte](64) var buffer = new Array[Byte](64)
def write(addr: MemoryLocation, length: Int): Unit = { def write(base: Object, offset: Long, length: Int): Unit = {
if (buffer.length < length) { if (buffer.length < length) {
buffer = new Array[Byte](length) buffer = new Array[Byte](length)
} }
Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset, Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
buffer, Platform.BYTE_ARRAY_OFFSET, length)
out.write(buffer, 0, length) out.write(buffer, 0, length)
} }
@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
// [key size] [values size] [key bytes] [values bytes] // [key size] [values size] [key bytes] [values bytes]
out.writeInt(loc.getKeyLength) out.writeInt(loc.getKeyLength)
out.writeInt(loc.getValueLength) out.writeInt(loc.getValueLength)
write(loc.getKeyAddress, loc.getKeyLength) write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
write(loc.getValueAddress, loc.getValueLength) write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
} }
} else { } else {

View file

@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
} }
/* /*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------------
Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X
Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X
*/ */
} }
@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.addCase("hash") { iter => benchmark.addCase("hash") { iter =>
var i = 0 var i = 0
val keyBytes = new Array[Byte](16) val keyBytes = new Array[Byte](16)
val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1) val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
val value = new UnsafeRow(2)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var s = 0 var s = 0
while (i < N) { while (i < N) {
key.setInt(0, i % 1000) key.setInt(0, i % 1000)
val h = Murmur3_x86_32.hashUnsafeWords( val h = Murmur3_x86_32.hashUnsafeWords(
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
s += h s += h
i += 1 i += 1
} }
} }
benchmark.addCase("fast hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var s = 0
while (i < N) {
key.setInt(0, i % 1000)
val h = Murmur3_x86_32.hashLong(i % 1000, 42)
s += h
i += 1
}
}
benchmark.addCase("arrayEqual") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
value.setInt(0, 555)
var s = 0
while (i < N) {
key.setInt(0, i % 1000)
if (key.equals(value)) {
s += 1
}
i += 1
}
}
Seq("off", "on").foreach { heap => Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager( val taskMemoryManager = new TaskMemoryManager(
@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val valueBytes = new Array[Byte](16) val valueBytes = new Array[Byte](16)
val key = new UnsafeRow(1) val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
val value = new UnsafeRow(2) val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0 var i = 0
while (i < N) { while (i < N) {
key.setInt(0, i % 65536) key.setInt(0, i % 65536)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 65536, 42))
if (loc.isDefined) { if (loc.isDefined) {
value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
loc.getValueLength)
value.setInt(0, value.getInt(0) + 1) value.setInt(0, value.getInt(0) + 1)
i += 1 i += 1
} else { } else {
@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------- -------------------------------------------------------------------------------------------
hash 628 / 661 83.0 12.0 1.0X hash 651 / 678 80.0 12.5 1.0X
BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X fast hash 336 / 343 155.9 6.4 1.9X
BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X arrayEqual 417 / 428 125.0 8.0 1.6X
BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
*/ */
benchmark.run() benchmark.run()
} }