[SPARK-12950] [SQL] Improve lookup of BytesToBytesMap in aggregate
This PR improve the lookup of BytesToBytesMap by: 1. Generate code for calculate the hash code of grouping keys. 2. Do not use MemoryLocation, fetch the baseObject and offset for key and value directly (remove the indirection). Author: Davies Liu <davies@databricks.com> Closes #11010 from davies/gen_map.
This commit is contained in:
parent
fae830d158
commit
0e5ebac3c1
|
@ -38,7 +38,6 @@ import org.apache.spark.unsafe.array.ByteArrayMethods;
|
||||||
import org.apache.spark.unsafe.array.LongArray;
|
import org.apache.spark.unsafe.array.LongArray;
|
||||||
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
|
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
|
||||||
import org.apache.spark.unsafe.memory.MemoryBlock;
|
import org.apache.spark.unsafe.memory.MemoryBlock;
|
||||||
import org.apache.spark.unsafe.memory.MemoryLocation;
|
|
||||||
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
|
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
|
||||||
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
|
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
|
||||||
|
|
||||||
|
@ -65,8 +64,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
|
|
||||||
private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
|
private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
|
||||||
|
|
||||||
private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
|
|
||||||
|
|
||||||
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
|
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
|
||||||
|
|
||||||
private final TaskMemoryManager taskMemoryManager;
|
private final TaskMemoryManager taskMemoryManager;
|
||||||
|
@ -417,7 +414,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
* This function always return the same {@link Location} instance to avoid object allocation.
|
* This function always return the same {@link Location} instance to avoid object allocation.
|
||||||
*/
|
*/
|
||||||
public Location lookup(Object keyBase, long keyOffset, int keyLength) {
|
public Location lookup(Object keyBase, long keyOffset, int keyLength) {
|
||||||
safeLookup(keyBase, keyOffset, keyLength, loc);
|
safeLookup(keyBase, keyOffset, keyLength, loc,
|
||||||
|
Murmur3_x86_32.hashUnsafeWords(keyBase, keyOffset, keyLength, 42));
|
||||||
|
return loc;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Looks up a key, and return a {@link Location} handle that can be used to test existence
|
||||||
|
* and read/write values.
|
||||||
|
*
|
||||||
|
* This function always return the same {@link Location} instance to avoid object allocation.
|
||||||
|
*/
|
||||||
|
public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) {
|
||||||
|
safeLookup(keyBase, keyOffset, keyLength, loc, hash);
|
||||||
return loc;
|
return loc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -426,14 +435,13 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
*
|
*
|
||||||
* This is a thread-safe version of `lookup`, could be used by multiple threads.
|
* This is a thread-safe version of `lookup`, could be used by multiple threads.
|
||||||
*/
|
*/
|
||||||
public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
|
public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc, int hash) {
|
||||||
assert(longArray != null);
|
assert(longArray != null);
|
||||||
|
|
||||||
if (enablePerfMetrics) {
|
if (enablePerfMetrics) {
|
||||||
numKeyLookups++;
|
numKeyLookups++;
|
||||||
}
|
}
|
||||||
final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
|
int pos = hash & mask;
|
||||||
int pos = hashcode & mask;
|
|
||||||
int step = 1;
|
int step = 1;
|
||||||
while (true) {
|
while (true) {
|
||||||
if (enablePerfMetrics) {
|
if (enablePerfMetrics) {
|
||||||
|
@ -441,22 +449,19 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
}
|
}
|
||||||
if (longArray.get(pos * 2) == 0) {
|
if (longArray.get(pos * 2) == 0) {
|
||||||
// This is a new key.
|
// This is a new key.
|
||||||
loc.with(pos, hashcode, false);
|
loc.with(pos, hash, false);
|
||||||
return;
|
return;
|
||||||
} else {
|
} else {
|
||||||
long stored = longArray.get(pos * 2 + 1);
|
long stored = longArray.get(pos * 2 + 1);
|
||||||
if ((int) (stored) == hashcode) {
|
if ((int) (stored) == hash) {
|
||||||
// Full hash code matches. Let's compare the keys for equality.
|
// Full hash code matches. Let's compare the keys for equality.
|
||||||
loc.with(pos, hashcode, true);
|
loc.with(pos, hash, true);
|
||||||
if (loc.getKeyLength() == keyLength) {
|
if (loc.getKeyLength() == keyLength) {
|
||||||
final MemoryLocation keyAddress = loc.getKeyAddress();
|
|
||||||
final Object storedkeyBase = keyAddress.getBaseObject();
|
|
||||||
final long storedkeyOffset = keyAddress.getBaseOffset();
|
|
||||||
final boolean areEqual = ByteArrayMethods.arrayEquals(
|
final boolean areEqual = ByteArrayMethods.arrayEquals(
|
||||||
keyBase,
|
keyBase,
|
||||||
keyOffset,
|
keyOffset,
|
||||||
storedkeyBase,
|
loc.getKeyBase(),
|
||||||
storedkeyOffset,
|
loc.getKeyOffset(),
|
||||||
keyLength
|
keyLength
|
||||||
);
|
);
|
||||||
if (areEqual) {
|
if (areEqual) {
|
||||||
|
@ -484,13 +489,14 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
private boolean isDefined;
|
private boolean isDefined;
|
||||||
/**
|
/**
|
||||||
* The hashcode of the most recent key passed to
|
* The hashcode of the most recent key passed to
|
||||||
* {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
|
* {@link BytesToBytesMap#lookup(Object, long, int, int)}. Caching this hashcode here allows us
|
||||||
* avoid re-hashing the key when storing a value for that key.
|
* to avoid re-hashing the key when storing a value for that key.
|
||||||
*/
|
*/
|
||||||
private int keyHashcode;
|
private int keyHashcode;
|
||||||
private final MemoryLocation keyMemoryLocation = new MemoryLocation();
|
private Object baseObject; // the base object for key and value
|
||||||
private final MemoryLocation valueMemoryLocation = new MemoryLocation();
|
private long keyOffset;
|
||||||
private int keyLength;
|
private int keyLength;
|
||||||
|
private long valueOffset;
|
||||||
private int valueLength;
|
private int valueLength;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -504,18 +510,15 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
taskMemoryManager.getOffsetInPage(fullKeyAddress));
|
taskMemoryManager.getOffsetInPage(fullKeyAddress));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateAddressesAndSizes(final Object base, final long offset) {
|
private void updateAddressesAndSizes(final Object base, long offset) {
|
||||||
long position = offset;
|
baseObject = base;
|
||||||
final int totalLength = Platform.getInt(base, position);
|
final int totalLength = Platform.getInt(base, offset);
|
||||||
position += 4;
|
offset += 4;
|
||||||
keyLength = Platform.getInt(base, position);
|
keyLength = Platform.getInt(base, offset);
|
||||||
position += 4;
|
offset += 4;
|
||||||
|
keyOffset = offset;
|
||||||
|
valueOffset = offset + keyLength;
|
||||||
valueLength = totalLength - keyLength - 4;
|
valueLength = totalLength - keyLength - 4;
|
||||||
|
|
||||||
keyMemoryLocation.setObjAndOffset(base, position);
|
|
||||||
|
|
||||||
position += keyLength;
|
|
||||||
valueMemoryLocation.setObjAndOffset(base, position);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Location with(int pos, int keyHashcode, boolean isDefined) {
|
private Location with(int pos, int keyHashcode, boolean isDefined) {
|
||||||
|
@ -543,10 +546,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
private Location with(Object base, long offset, int length) {
|
private Location with(Object base, long offset, int length) {
|
||||||
this.isDefined = true;
|
this.isDefined = true;
|
||||||
this.memoryPage = null;
|
this.memoryPage = null;
|
||||||
|
baseObject = base;
|
||||||
|
keyOffset = offset + 4;
|
||||||
keyLength = Platform.getInt(base, offset);
|
keyLength = Platform.getInt(base, offset);
|
||||||
|
valueOffset = offset + 4 + keyLength;
|
||||||
valueLength = length - 4 - keyLength;
|
valueLength = length - 4 - keyLength;
|
||||||
keyMemoryLocation.setObjAndOffset(base, offset + 4);
|
|
||||||
valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -566,14 +570,35 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the address of the key defined at this position.
|
* Returns the base object for key.
|
||||||
* This points to the first byte of the key data.
|
|
||||||
* Unspecified behavior if the key is not defined.
|
|
||||||
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
|
|
||||||
*/
|
*/
|
||||||
public MemoryLocation getKeyAddress() {
|
public Object getKeyBase() {
|
||||||
assert (isDefined);
|
assert (isDefined);
|
||||||
return keyMemoryLocation;
|
return baseObject;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the offset for key.
|
||||||
|
*/
|
||||||
|
public long getKeyOffset() {
|
||||||
|
assert (isDefined);
|
||||||
|
return keyOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the base object for value.
|
||||||
|
*/
|
||||||
|
public Object getValueBase() {
|
||||||
|
assert (isDefined);
|
||||||
|
return baseObject;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the offset for value.
|
||||||
|
*/
|
||||||
|
public long getValueOffset() {
|
||||||
|
assert (isDefined);
|
||||||
|
return valueOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -585,17 +610,6 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
||||||
return keyLength;
|
return keyLength;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the address of the value defined at this position.
|
|
||||||
* This points to the first byte of the value data.
|
|
||||||
* Unspecified behavior if the key is not defined.
|
|
||||||
* For efficiency reasons, calls to this method always returns the same MemoryLocation object.
|
|
||||||
*/
|
|
||||||
public MemoryLocation getValueAddress() {
|
|
||||||
assert (isDefined);
|
|
||||||
return valueMemoryLocation;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the length of the value defined at this position.
|
* Returns the length of the value defined at this position.
|
||||||
* Unspecified behavior if the key is not defined.
|
* Unspecified behavior if the key is not defined.
|
||||||
|
|
|
@ -39,14 +39,13 @@ import org.mockito.stubbing.Answer;
|
||||||
|
|
||||||
import org.apache.spark.SparkConf;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.executor.ShuffleWriteMetrics;
|
import org.apache.spark.executor.ShuffleWriteMetrics;
|
||||||
import org.apache.spark.memory.TestMemoryManager;
|
|
||||||
import org.apache.spark.memory.TaskMemoryManager;
|
import org.apache.spark.memory.TaskMemoryManager;
|
||||||
|
import org.apache.spark.memory.TestMemoryManager;
|
||||||
import org.apache.spark.network.util.JavaUtils;
|
import org.apache.spark.network.util.JavaUtils;
|
||||||
import org.apache.spark.serializer.SerializerInstance;
|
import org.apache.spark.serializer.SerializerInstance;
|
||||||
import org.apache.spark.storage.*;
|
import org.apache.spark.storage.*;
|
||||||
import org.apache.spark.unsafe.Platform;
|
import org.apache.spark.unsafe.Platform;
|
||||||
import org.apache.spark.unsafe.array.ByteArrayMethods;
|
import org.apache.spark.unsafe.array.ByteArrayMethods;
|
||||||
import org.apache.spark.unsafe.memory.MemoryLocation;
|
|
||||||
import org.apache.spark.util.Utils;
|
import org.apache.spark.util.Utils;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.greaterThan;
|
import static org.hamcrest.Matchers.greaterThan;
|
||||||
|
@ -142,10 +141,9 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
|
|
||||||
protected abstract boolean useOffHeapMemoryAllocator();
|
protected abstract boolean useOffHeapMemoryAllocator();
|
||||||
|
|
||||||
private static byte[] getByteArray(MemoryLocation loc, int size) {
|
private static byte[] getByteArray(Object base, long offset, int size) {
|
||||||
final byte[] arr = new byte[size];
|
final byte[] arr = new byte[size];
|
||||||
Platform.copyMemory(
|
Platform.copyMemory(base, offset, arr, Platform.BYTE_ARRAY_OFFSET, size);
|
||||||
loc.getBaseObject(), loc.getBaseOffset(), arr, Platform.BYTE_ARRAY_OFFSET, size);
|
|
||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,13 +161,14 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
*/
|
*/
|
||||||
private static boolean arrayEquals(
|
private static boolean arrayEquals(
|
||||||
byte[] expected,
|
byte[] expected,
|
||||||
MemoryLocation actualAddr,
|
Object base,
|
||||||
|
long offset,
|
||||||
long actualLengthBytes) {
|
long actualLengthBytes) {
|
||||||
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
|
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
|
||||||
expected,
|
expected,
|
||||||
Platform.BYTE_ARRAY_OFFSET,
|
Platform.BYTE_ARRAY_OFFSET,
|
||||||
actualAddr.getBaseObject(),
|
base,
|
||||||
actualAddr.getBaseOffset(),
|
offset,
|
||||||
expected.length
|
expected.length
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -212,16 +211,20 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
// reflect the result of this store without us having to call lookup() again on the same key.
|
// reflect the result of this store without us having to call lookup() again on the same key.
|
||||||
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
|
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
|
||||||
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
|
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
|
||||||
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
|
Assert.assertArrayEquals(keyData,
|
||||||
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
|
getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
|
||||||
|
Assert.assertArrayEquals(valueData,
|
||||||
|
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
|
||||||
|
|
||||||
// After calling lookup() the location should still point to the correct data.
|
// After calling lookup() the location should still point to the correct data.
|
||||||
Assert.assertTrue(
|
Assert.assertTrue(
|
||||||
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
|
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
|
||||||
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
|
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
|
||||||
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
|
Assert.assertEquals(recordLengthBytes, loc.getValueLength());
|
||||||
Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
|
Assert.assertArrayEquals(keyData,
|
||||||
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
|
getByteArray(loc.getKeyBase(), loc.getKeyOffset(), recordLengthBytes));
|
||||||
|
Assert.assertArrayEquals(valueData,
|
||||||
|
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Assert.assertTrue(loc.putNewKey(
|
Assert.assertTrue(loc.putNewKey(
|
||||||
|
@ -283,15 +286,12 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
final BytesToBytesMap.Location loc = iter.next();
|
final BytesToBytesMap.Location loc = iter.next();
|
||||||
Assert.assertTrue(loc.isDefined());
|
Assert.assertTrue(loc.isDefined());
|
||||||
final MemoryLocation keyAddress = loc.getKeyAddress();
|
final long value = Platform.getLong(loc.getValueBase(), loc.getValueOffset());
|
||||||
final MemoryLocation valueAddress = loc.getValueAddress();
|
|
||||||
final long value = Platform.getLong(
|
|
||||||
valueAddress.getBaseObject(), valueAddress.getBaseOffset());
|
|
||||||
final long keyLength = loc.getKeyLength();
|
final long keyLength = loc.getKeyLength();
|
||||||
if (keyLength == 0) {
|
if (keyLength == 0) {
|
||||||
Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
|
Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
|
||||||
} else {
|
} else {
|
||||||
final long key = Platform.getLong(keyAddress.getBaseObject(), keyAddress.getBaseOffset());
|
final long key = Platform.getLong(loc.getKeyBase(), loc.getKeyOffset());
|
||||||
Assert.assertEquals(value, key);
|
Assert.assertEquals(value, key);
|
||||||
}
|
}
|
||||||
valuesSeen.set((int) value);
|
valuesSeen.set((int) value);
|
||||||
|
@ -365,15 +365,15 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
|
Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
|
||||||
Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
|
Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
|
||||||
Platform.copyMemory(
|
Platform.copyMemory(
|
||||||
loc.getKeyAddress().getBaseObject(),
|
loc.getKeyBase(),
|
||||||
loc.getKeyAddress().getBaseOffset(),
|
loc.getKeyOffset(),
|
||||||
key,
|
key,
|
||||||
Platform.LONG_ARRAY_OFFSET,
|
Platform.LONG_ARRAY_OFFSET,
|
||||||
KEY_LENGTH
|
KEY_LENGTH
|
||||||
);
|
);
|
||||||
Platform.copyMemory(
|
Platform.copyMemory(
|
||||||
loc.getValueAddress().getBaseObject(),
|
loc.getValueBase(),
|
||||||
loc.getValueAddress().getBaseOffset(),
|
loc.getValueOffset(),
|
||||||
value,
|
value,
|
||||||
Platform.LONG_ARRAY_OFFSET,
|
Platform.LONG_ARRAY_OFFSET,
|
||||||
VALUE_LENGTH
|
VALUE_LENGTH
|
||||||
|
@ -425,8 +425,9 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
Assert.assertTrue(loc.isDefined());
|
Assert.assertTrue(loc.isDefined());
|
||||||
Assert.assertEquals(key.length, loc.getKeyLength());
|
Assert.assertEquals(key.length, loc.getKeyLength());
|
||||||
Assert.assertEquals(value.length, loc.getValueLength());
|
Assert.assertEquals(value.length, loc.getValueLength());
|
||||||
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
|
Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
|
||||||
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
|
Assert.assertTrue(
|
||||||
|
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -436,8 +437,10 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
final BytesToBytesMap.Location loc =
|
final BytesToBytesMap.Location loc =
|
||||||
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
|
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
|
||||||
Assert.assertTrue(loc.isDefined());
|
Assert.assertTrue(loc.isDefined());
|
||||||
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
|
Assert.assertTrue(
|
||||||
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
|
arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
|
||||||
|
Assert.assertTrue(
|
||||||
|
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
map.free();
|
map.free();
|
||||||
|
@ -476,8 +479,9 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
Assert.assertTrue(loc.isDefined());
|
Assert.assertTrue(loc.isDefined());
|
||||||
Assert.assertEquals(key.length, loc.getKeyLength());
|
Assert.assertEquals(key.length, loc.getKeyLength());
|
||||||
Assert.assertEquals(value.length, loc.getValueLength());
|
Assert.assertEquals(value.length, loc.getValueLength());
|
||||||
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
|
Assert.assertTrue(arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), key.length));
|
||||||
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
|
Assert.assertTrue(
|
||||||
|
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), value.length));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
|
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
|
||||||
|
@ -486,8 +490,10 @@ public abstract class AbstractBytesToBytesMapSuite {
|
||||||
final BytesToBytesMap.Location loc =
|
final BytesToBytesMap.Location loc =
|
||||||
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
|
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
|
||||||
Assert.assertTrue(loc.isDefined());
|
Assert.assertTrue(loc.isDefined());
|
||||||
Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
|
Assert.assertTrue(
|
||||||
Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
|
arrayEquals(key, loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength()));
|
||||||
|
Assert.assertTrue(
|
||||||
|
arrayEquals(value, loc.getValueBase(), loc.getValueOffset(), loc.getValueLength()));
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
map.free();
|
map.free();
|
||||||
|
|
|
@ -40,6 +40,7 @@ object MimaExcludes {
|
||||||
excludePackage("org.apache.spark.rpc"),
|
excludePackage("org.apache.spark.rpc"),
|
||||||
excludePackage("org.spark-project.jetty"),
|
excludePackage("org.spark-project.jetty"),
|
||||||
excludePackage("org.apache.spark.unused"),
|
excludePackage("org.apache.spark.unused"),
|
||||||
|
excludePackage("org.apache.spark.unsafe"),
|
||||||
excludePackage("org.apache.spark.util.collection.unsafe"),
|
excludePackage("org.apache.spark.util.collection.unsafe"),
|
||||||
excludePackage("org.apache.spark.sql.catalyst"),
|
excludePackage("org.apache.spark.sql.catalyst"),
|
||||||
excludePackage("org.apache.spark.sql.execution"),
|
excludePackage("org.apache.spark.sql.execution"),
|
||||||
|
|
|
@ -322,7 +322,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
|
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
|
||||||
ev.isNull = "false"
|
ev.isNull = "false"
|
||||||
val childrenHash = children.map { child =>
|
val childrenHash = children.map { child =>
|
||||||
|
|
|
@ -121,19 +121,24 @@ public final class UnsafeFixedWidthAggregationMap {
|
||||||
return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
|
return getAggregationBufferFromUnsafeRow(unsafeGroupingKeyRow);
|
||||||
}
|
}
|
||||||
|
|
||||||
public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRow) {
|
public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key) {
|
||||||
|
return getAggregationBufferFromUnsafeRow(key, key.hashCode());
|
||||||
|
}
|
||||||
|
|
||||||
|
public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) {
|
||||||
// Probe our map using the serialized key
|
// Probe our map using the serialized key
|
||||||
final BytesToBytesMap.Location loc = map.lookup(
|
final BytesToBytesMap.Location loc = map.lookup(
|
||||||
unsafeGroupingKeyRow.getBaseObject(),
|
key.getBaseObject(),
|
||||||
unsafeGroupingKeyRow.getBaseOffset(),
|
key.getBaseOffset(),
|
||||||
unsafeGroupingKeyRow.getSizeInBytes());
|
key.getSizeInBytes(),
|
||||||
|
hash);
|
||||||
if (!loc.isDefined()) {
|
if (!loc.isDefined()) {
|
||||||
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
|
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
|
||||||
// empty aggregation buffer into the map:
|
// empty aggregation buffer into the map:
|
||||||
boolean putSucceeded = loc.putNewKey(
|
boolean putSucceeded = loc.putNewKey(
|
||||||
unsafeGroupingKeyRow.getBaseObject(),
|
key.getBaseObject(),
|
||||||
unsafeGroupingKeyRow.getBaseOffset(),
|
key.getBaseOffset(),
|
||||||
unsafeGroupingKeyRow.getSizeInBytes(),
|
key.getSizeInBytes(),
|
||||||
emptyAggregationBuffer,
|
emptyAggregationBuffer,
|
||||||
Platform.BYTE_ARRAY_OFFSET,
|
Platform.BYTE_ARRAY_OFFSET,
|
||||||
emptyAggregationBuffer.length
|
emptyAggregationBuffer.length
|
||||||
|
@ -144,10 +149,9 @@ public final class UnsafeFixedWidthAggregationMap {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the pointer to point to the value that we just stored or looked up:
|
// Reset the pointer to point to the value that we just stored or looked up:
|
||||||
final MemoryLocation address = loc.getValueAddress();
|
|
||||||
currentAggregationBuffer.pointTo(
|
currentAggregationBuffer.pointTo(
|
||||||
address.getBaseObject(),
|
loc.getValueBase(),
|
||||||
address.getBaseOffset(),
|
loc.getValueOffset(),
|
||||||
loc.getValueLength()
|
loc.getValueLength()
|
||||||
);
|
);
|
||||||
return currentAggregationBuffer;
|
return currentAggregationBuffer;
|
||||||
|
@ -172,16 +176,14 @@ public final class UnsafeFixedWidthAggregationMap {
|
||||||
public boolean next() {
|
public boolean next() {
|
||||||
if (mapLocationIterator.hasNext()) {
|
if (mapLocationIterator.hasNext()) {
|
||||||
final BytesToBytesMap.Location loc = mapLocationIterator.next();
|
final BytesToBytesMap.Location loc = mapLocationIterator.next();
|
||||||
final MemoryLocation keyAddress = loc.getKeyAddress();
|
|
||||||
final MemoryLocation valueAddress = loc.getValueAddress();
|
|
||||||
key.pointTo(
|
key.pointTo(
|
||||||
keyAddress.getBaseObject(),
|
loc.getKeyBase(),
|
||||||
keyAddress.getBaseOffset(),
|
loc.getKeyOffset(),
|
||||||
loc.getKeyLength()
|
loc.getKeyLength()
|
||||||
);
|
);
|
||||||
value.pointTo(
|
value.pointTo(
|
||||||
valueAddress.getBaseObject(),
|
loc.getValueBase(),
|
||||||
valueAddress.getBaseOffset(),
|
loc.getValueOffset(),
|
||||||
loc.getValueLength()
|
loc.getValueLength()
|
||||||
);
|
);
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -97,8 +97,8 @@ public final class UnsafeKVExternalSorter {
|
||||||
UnsafeRow row = new UnsafeRow(numKeyFields);
|
UnsafeRow row = new UnsafeRow(numKeyFields);
|
||||||
while (iter.hasNext()) {
|
while (iter.hasNext()) {
|
||||||
final BytesToBytesMap.Location loc = iter.next();
|
final BytesToBytesMap.Location loc = iter.next();
|
||||||
final Object baseObject = loc.getKeyAddress().getBaseObject();
|
final Object baseObject = loc.getKeyBase();
|
||||||
final long baseOffset = loc.getKeyAddress().getBaseOffset();
|
final long baseOffset = loc.getKeyOffset();
|
||||||
|
|
||||||
// Get encoded memory address
|
// Get encoded memory address
|
||||||
// baseObject + baseOffset point to the beginning of the key data in the map, but that
|
// baseObject + baseOffset point to the beginning of the key data in the map, but that
|
||||||
|
|
|
@ -366,11 +366,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
|
||||||
def apply(plan: SparkPlan): SparkPlan = {
|
def apply(plan: SparkPlan): SparkPlan = {
|
||||||
if (sqlContext.conf.wholeStageEnabled) {
|
if (sqlContext.conf.wholeStageEnabled) {
|
||||||
plan.transform {
|
plan.transform {
|
||||||
case plan: CodegenSupport if supportCodegen(plan) &&
|
case plan: CodegenSupport if supportCodegen(plan) =>
|
||||||
// Whole stage codegen is only useful when there are at least two levels of operators that
|
|
||||||
// support it (save at least one projection/iterator).
|
|
||||||
(Utils.isTesting || plan.children.exists(supportCodegen)) =>
|
|
||||||
|
|
||||||
var inputs = ArrayBuffer[SparkPlan]()
|
var inputs = ArrayBuffer[SparkPlan]()
|
||||||
val combined = plan.transform {
|
val combined = plan.transform {
|
||||||
// The build side can't be compiled together
|
// The build side can't be compiled together
|
||||||
|
|
|
@ -501,6 +501,11 @@ case class TungstenAggregate(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generate hash code for key
|
||||||
|
val hashExpr = Murmur3Hash(groupingExpressions, 42)
|
||||||
|
ctx.currentVars = input
|
||||||
|
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
|
||||||
|
|
||||||
val inputAttr = bufferAttributes ++ child.output
|
val inputAttr = bufferAttributes ++ child.output
|
||||||
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
|
ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
|
||||||
ctx.INPUT_ROW = buffer
|
ctx.INPUT_ROW = buffer
|
||||||
|
@ -526,10 +531,11 @@ case class TungstenAggregate(
|
||||||
s"""
|
s"""
|
||||||
// generate grouping key
|
// generate grouping key
|
||||||
${keyCode.code.trim}
|
${keyCode.code.trim}
|
||||||
|
${hashEval.code.trim}
|
||||||
UnsafeRow $buffer = null;
|
UnsafeRow $buffer = null;
|
||||||
if ($checkFallback) {
|
if ($checkFallback) {
|
||||||
// try to get the buffer from hash map
|
// try to get the buffer from hash map
|
||||||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
|
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
|
||||||
}
|
}
|
||||||
if ($buffer == null) {
|
if ($buffer == null) {
|
||||||
if ($sorterTerm == null) {
|
if ($sorterTerm == null) {
|
||||||
|
@ -540,7 +546,7 @@ case class TungstenAggregate(
|
||||||
$resetCoulter
|
$resetCoulter
|
||||||
// the hash map had be spilled, it should have enough memory now,
|
// the hash map had be spilled, it should have enough memory now,
|
||||||
// try to allocate buffer again.
|
// try to allocate buffer again.
|
||||||
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
|
$buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value});
|
||||||
if ($buffer == null) {
|
if ($buffer == null) {
|
||||||
// failed to allocate the first page
|
// failed to allocate the first page
|
||||||
throw new OutOfMemoryError("No enough memory for aggregation");
|
throw new OutOfMemoryError("No enough memory for aggregation");
|
||||||
|
|
|
@ -277,13 +277,13 @@ private[joins] final class UnsafeHashedRelation(
|
||||||
val map = binaryMap // avoid the compiler error
|
val map = binaryMap // avoid the compiler error
|
||||||
val loc = new map.Location // this could be allocated in stack
|
val loc = new map.Location // this could be allocated in stack
|
||||||
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
|
binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
|
||||||
unsafeKey.getSizeInBytes, loc)
|
unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
|
||||||
if (loc.isDefined) {
|
if (loc.isDefined) {
|
||||||
val buffer = CompactBuffer[UnsafeRow]()
|
val buffer = CompactBuffer[UnsafeRow]()
|
||||||
|
|
||||||
val base = loc.getValueAddress.getBaseObject
|
val base = loc.getValueBase
|
||||||
var offset = loc.getValueAddress.getBaseOffset
|
var offset = loc.getValueOffset
|
||||||
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
|
val last = offset + loc.getValueLength
|
||||||
while (offset < last) {
|
while (offset < last) {
|
||||||
val numFields = Platform.getInt(base, offset)
|
val numFields = Platform.getInt(base, offset)
|
||||||
val sizeInBytes = Platform.getInt(base, offset + 4)
|
val sizeInBytes = Platform.getInt(base, offset + 4)
|
||||||
|
@ -311,12 +311,11 @@ private[joins] final class UnsafeHashedRelation(
|
||||||
out.writeInt(binaryMap.numElements())
|
out.writeInt(binaryMap.numElements())
|
||||||
|
|
||||||
var buffer = new Array[Byte](64)
|
var buffer = new Array[Byte](64)
|
||||||
def write(addr: MemoryLocation, length: Int): Unit = {
|
def write(base: Object, offset: Long, length: Int): Unit = {
|
||||||
if (buffer.length < length) {
|
if (buffer.length < length) {
|
||||||
buffer = new Array[Byte](length)
|
buffer = new Array[Byte](length)
|
||||||
}
|
}
|
||||||
Platform.copyMemory(addr.getBaseObject, addr.getBaseOffset,
|
Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
|
||||||
buffer, Platform.BYTE_ARRAY_OFFSET, length)
|
|
||||||
out.write(buffer, 0, length)
|
out.write(buffer, 0, length)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -326,8 +325,8 @@ private[joins] final class UnsafeHashedRelation(
|
||||||
// [key size] [values size] [key bytes] [values bytes]
|
// [key size] [values size] [key bytes] [values bytes]
|
||||||
out.writeInt(loc.getKeyLength)
|
out.writeInt(loc.getKeyLength)
|
||||||
out.writeInt(loc.getValueLength)
|
out.writeInt(loc.getValueLength)
|
||||||
write(loc.getKeyAddress, loc.getKeyLength)
|
write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
|
||||||
write(loc.getValueAddress, loc.getValueLength)
|
write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -114,11 +114,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
||||||
Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
||||||
-------------------------------------------------------------------------------------------
|
-------------------------------------------------------------------------------------------
|
||||||
Aggregate w keys codegen=false 2402 / 2551 8.0 125.0 1.0X
|
Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X
|
||||||
Aggregate w keys codegen=true 1620 / 1670 12.0 83.3 1.5X
|
Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X
|
||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,21 +165,51 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
||||||
benchmark.addCase("hash") { iter =>
|
benchmark.addCase("hash") { iter =>
|
||||||
var i = 0
|
var i = 0
|
||||||
val keyBytes = new Array[Byte](16)
|
val keyBytes = new Array[Byte](16)
|
||||||
val valueBytes = new Array[Byte](16)
|
|
||||||
val key = new UnsafeRow(1)
|
val key = new UnsafeRow(1)
|
||||||
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
val value = new UnsafeRow(2)
|
|
||||||
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
|
||||||
var s = 0
|
var s = 0
|
||||||
while (i < N) {
|
while (i < N) {
|
||||||
key.setInt(0, i % 1000)
|
key.setInt(0, i % 1000)
|
||||||
val h = Murmur3_x86_32.hashUnsafeWords(
|
val h = Murmur3_x86_32.hashUnsafeWords(
|
||||||
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
|
key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42)
|
||||||
s += h
|
s += h
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
benchmark.addCase("fast hash") { iter =>
|
||||||
|
var i = 0
|
||||||
|
val keyBytes = new Array[Byte](16)
|
||||||
|
val key = new UnsafeRow(1)
|
||||||
|
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
|
var s = 0
|
||||||
|
while (i < N) {
|
||||||
|
key.setInt(0, i % 1000)
|
||||||
|
val h = Murmur3_x86_32.hashLong(i % 1000, 42)
|
||||||
|
s += h
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmark.addCase("arrayEqual") { iter =>
|
||||||
|
var i = 0
|
||||||
|
val keyBytes = new Array[Byte](16)
|
||||||
|
val valueBytes = new Array[Byte](16)
|
||||||
|
val key = new UnsafeRow(1)
|
||||||
|
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
|
val value = new UnsafeRow(1)
|
||||||
|
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
|
value.setInt(0, 555)
|
||||||
|
var s = 0
|
||||||
|
while (i < N) {
|
||||||
|
key.setInt(0, i % 1000)
|
||||||
|
if (key.equals(value)) {
|
||||||
|
s += 1
|
||||||
|
}
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Seq("off", "on").foreach { heap =>
|
Seq("off", "on").foreach { heap =>
|
||||||
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
|
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
|
||||||
val taskMemoryManager = new TaskMemoryManager(
|
val taskMemoryManager = new TaskMemoryManager(
|
||||||
|
@ -195,15 +225,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
||||||
val valueBytes = new Array[Byte](16)
|
val valueBytes = new Array[Byte](16)
|
||||||
val key = new UnsafeRow(1)
|
val key = new UnsafeRow(1)
|
||||||
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
val value = new UnsafeRow(2)
|
val value = new UnsafeRow(1)
|
||||||
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < N) {
|
while (i < N) {
|
||||||
key.setInt(0, i % 65536)
|
key.setInt(0, i % 65536)
|
||||||
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
|
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
|
||||||
|
Murmur3_x86_32.hashLong(i % 65536, 42))
|
||||||
if (loc.isDefined) {
|
if (loc.isDefined) {
|
||||||
value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
|
value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
|
||||||
loc.getValueLength)
|
|
||||||
value.setInt(0, value.getInt(0) + 1)
|
value.setInt(0, value.getInt(0) + 1)
|
||||||
i += 1
|
i += 1
|
||||||
} else {
|
} else {
|
||||||
|
@ -218,9 +248,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
|
||||||
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
|
||||||
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
|
||||||
-------------------------------------------------------------------------------------------
|
-------------------------------------------------------------------------------------------
|
||||||
hash 628 / 661 83.0 12.0 1.0X
|
hash 651 / 678 80.0 12.5 1.0X
|
||||||
BytesToBytesMap (off Heap) 3292 / 3408 15.0 66.7 0.2X
|
fast hash 336 / 343 155.9 6.4 1.9X
|
||||||
BytesToBytesMap (on Heap) 3349 / 4267 15.0 66.7 0.2X
|
arrayEqual 417 / 428 125.0 8.0 1.6X
|
||||||
|
BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
|
||||||
|
BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
|
||||||
*/
|
*/
|
||||||
benchmark.run()
|
benchmark.run()
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue