[SPARK-9517][SQL] BytesToBytesMap should encode data the same way as UnsafeExternalSorter

BytesToBytesMap current encodes key/value data in the following format:
```
8B key length, key data, 8B value length, value data
```

UnsafeExternalSorter, on the other hand, encodes data this way:
```
4B record length, data
```

As a result, we cannot pass records encoded by BytesToBytesMap directly into UnsafeExternalSorter for sorting. However, if we rearrange data slightly, we can then pass the key/value records directly into UnsafeExternalSorter:
```
4B key+value length, 4B key length, key data, value data
```

Author: Reynold Xin <rxin@databricks.com>

Closes #7845 from rxin/kvsort-rebase and squashes the following commits:

5716b59 [Reynold Xin] Fixed test.
2e62ccb [Reynold Xin] Updated BytesToBytesMap's data encoding to put the key first.
a51b641 [Reynold Xin] Added a KV sorter interface.
This commit is contained in:
Reynold Xin 2015-07-31 23:55:16 -07:00
parent 67ad4e21fc
commit d90f2cf7a2
8 changed files with 175 additions and 90 deletions

View file

@ -17,7 +17,6 @@
package org.apache.spark.unsafe.map;
import java.io.IOException;
import java.lang.Override;
import java.lang.UnsupportedOperationException;
import java.util.Iterator;
@ -212,7 +211,7 @@ public final class BytesToBytesMap {
*/
public int numElements() { return numElements; }
private static final class BytesToBytesMapIterator implements Iterator<Location> {
public static final class BytesToBytesMapIterator implements Iterator<Location> {
private final int numRecords;
private final Iterator<MemoryBlock> dataPagesIterator;
@ -222,7 +221,8 @@ public final class BytesToBytesMap {
private Object pageBaseObject;
private long offsetInPage;
BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
private BytesToBytesMapIterator(
int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
this.numRecords = numRecords;
this.dataPagesIterator = dataPagesIterator;
this.loc = loc;
@ -244,13 +244,13 @@ public final class BytesToBytesMap {
@Override
public Location next() {
int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
if (keyLength == END_OF_PAGE_MARKER) {
int totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
if (totalLength == END_OF_PAGE_MARKER) {
advanceToNextPage();
keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
}
loc.with(pageBaseObject, offsetInPage);
offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
offsetInPage += 8 + totalLength;
currentRecordNumber++;
return loc;
}
@ -269,7 +269,7 @@ public final class BytesToBytesMap {
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
*/
public Iterator<Location> iterator() {
public BytesToBytesMapIterator iterator() {
return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
}
@ -352,15 +352,18 @@ public final class BytesToBytesMap {
taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
long position = keyOffsetInPage;
keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
position += 8; // word used to store the key size
keyMemoryLocation.setObjAndOffset(page, position);
position += keyLength;
valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
position += 8; // word used to store the key size
valueMemoryLocation.setObjAndOffset(page, position);
private void updateAddressesAndSizes(final Object page, final long keyOffsetInPage) {
long position = keyOffsetInPage;
final int totalLength = PlatformDependent.UNSAFE.getInt(page, position);
position += 4;
keyLength = PlatformDependent.UNSAFE.getInt(page, position);
position += 4;
valueLength = totalLength - keyLength;
keyMemoryLocation.setObjAndOffset(page, position);
position += keyLength;
valueMemoryLocation.setObjAndOffset(page, position);
}
Location with(int pos, int keyHashcode, boolean isDefined) {
@ -478,7 +481,7 @@ public final class BytesToBytesMap {
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (8 byte value length) (value)
final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
// --- Figure out where to insert the new record ---------------------------------------------
@ -508,7 +511,7 @@ public final class BytesToBytesMap {
// There wasn't enough space in the current page, so write an end-of-page marker:
final Object pageBaseObject = currentDataPage.getBaseObject();
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
PlatformDependent.UNSAFE.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
@ -535,21 +538,22 @@ public final class BytesToBytesMap {
long insertCursor = dataPageInsertOffset;
// Compute all of our offsets up-front:
final long keySizeOffsetInPage = insertCursor;
insertCursor += 8; // word used to store the key size
final long totalLengthOffset = insertCursor;
insertCursor += 4;
final long keyLengthOffset = insertCursor;
insertCursor += 4;
final long keyDataOffsetInPage = insertCursor;
insertCursor += keyLengthBytes;
final long valueSizeOffsetInPage = insertCursor;
insertCursor += 8; // word used to store the value size
final long valueDataOffsetInPage = insertCursor;
insertCursor += valueLengthBytes; // word used to store the value size
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, totalLengthOffset,
keyLengthBytes + valueLengthBytes);
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
// Copy the key
PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
PlatformDependent.copyMemory(
keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
// Copy the value
PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
valueDataOffsetInPage, valueLengthBytes);
@ -557,7 +561,7 @@ public final class BytesToBytesMap {
if (useOverflowPage) {
// Store the end-of-page marker at the end of the data page
PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
} else {
pageCursor += requiredSize;
}
@ -565,7 +569,7 @@ public final class BytesToBytesMap {
numElements++;
bitset.set(pos);
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
dataPage, keySizeOffsetInPage);
dataPage, totalLengthOffset);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);

View file

@ -282,6 +282,21 @@ public final class UnsafeExternalSorter {
sorter.insertRecord(recordAddress, prefix);
}
/**
* Write a record to the sorter. The record is broken down into two different parts, and
*
*/
public void insertRecord(
Object recordBaseObject1,
long recordBaseOffset1,
int lengthInBytes1,
Object recordBaseObject2,
long recordBaseOffset2,
int lengthInBytes2,
long prefix) throws IOException {
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);

View file

@ -243,17 +243,17 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void iteratingOverDataPagesWithWastedSpace() throws Exception {
final int NUM_ENTRIES = 1000 * 1000;
final int KEY_LENGTH = 16;
final int KEY_LENGTH = 24;
final int VALUE_LENGTH = 40;
final BytesToBytesMap map = new BytesToBytesMap(
taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
// Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
// Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
// pages won't be evenly-divisible by records of this size, which will cause us to waste some
// space at the end of the page. This is necessary in order for us to take the end-of-record
// handling branch in iterator().
try {
for (int i = 0; i < NUM_ENTRIES; i++) {
final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes
final long[] key = new long[] { i, i, i }; // 3 * 8 = 24 bytes
final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
final BytesToBytesMap.Location loc = map.lookup(
key,

View file

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution;
import java.io.IOException;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.KVIterator;
public abstract class UnsafeKeyValueSorter {
public abstract void insert(UnsafeRow key, UnsafeRow value);
public abstract KVIterator<UnsafeRow, UnsafeRow> sort() throws IOException;
}

View file

@ -17,9 +17,6 @@
package org.apache.spark.sql.execution;
import java.io.IOException;
import java.util.Iterator;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
@ -28,6 +25,7 @@ import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
@ -156,54 +154,55 @@ public final class UnsafeFixedWidthAggregationMap {
return currentAggregationBuffer;
}
/**
* Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
*/
public static class MapEntry {
private MapEntry() { };
public final UnsafeRow key = new UnsafeRow();
public final UnsafeRow value = new UnsafeRow();
}
/**
* Returns an iterator over the keys and values in this map.
*
* For efficiency, each call returns the same object.
*/
public Iterator<MapEntry> iterator() {
return new Iterator<MapEntry>() {
public KVIterator<UnsafeRow, UnsafeRow> iterator() {
return new KVIterator<UnsafeRow, UnsafeRow>() {
private final MapEntry entry = new MapEntry();
private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator = map.iterator();
private final UnsafeRow key = new UnsafeRow();
private final UnsafeRow value = new UnsafeRow();
@Override
public boolean hasNext() {
return mapLocationIterator.hasNext();
public boolean next() {
if (mapLocationIterator.hasNext()) {
final BytesToBytesMap.Location loc = mapLocationIterator.next();
final MemoryLocation keyAddress = loc.getKeyAddress();
final MemoryLocation valueAddress = loc.getValueAddress();
key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
groupingKeySchema.length(),
loc.getKeyLength()
);
value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
aggregationBufferSchema.length(),
loc.getValueLength()
);
return true;
} else {
return false;
}
}
@Override
public MapEntry next() {
final BytesToBytesMap.Location loc = mapLocationIterator.next();
final MemoryLocation keyAddress = loc.getKeyAddress();
final MemoryLocation valueAddress = loc.getValueAddress();
entry.key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
groupingKeySchema.length(),
loc.getKeyLength()
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
aggregationBufferSchema.length(),
loc.getValueLength()
);
return entry;
public UnsafeRow getKey() {
return key;
}
@Override
public void remove() {
throw new UnsupportedOperationException();
public UnsafeRow getValue() {
return value;
}
@Override
public void close() {
// Do nothing.
}
};
}

View file

@ -287,21 +287,26 @@ case class GeneratedAggregate(
new Iterator[InternalRow] {
private[this] val mapIterator = aggregationMap.iterator()
private[this] val resultProjection = resultProjectionBuilder()
private[this] var _hasNext = mapIterator.next()
def hasNext: Boolean = mapIterator.hasNext
def hasNext: Boolean = _hasNext
def next(): InternalRow = {
val entry = mapIterator.next()
val result = resultProjection(joinedRow(entry.key, entry.value))
if (hasNext) {
result
if (_hasNext) {
val result = resultProjection(joinedRow(mapIterator.getKey, mapIterator.getValue))
_hasNext = mapIterator.next()
if (_hasNext) {
result
} else {
// This is the last element in the iterator, so let's free the buffer. Before we do,
// though, we need to make a defensive copy of the result so that we don't return an
// object that might contain dangling pointers to the freed memory
val resultCopy = result.copy()
aggregationMap.free()
resultCopy
}
} else {
// This is the last element in the iterator, so let's free the buffer. Before we do,
// though, we need to make a defensive copy of the result so that we don't return an
// object that might contain dangling pointers to the freed memory
val resultCopy = result.copy()
aggregationMap.free()
resultCopy
throw new java.util.NoSuchElementException
}
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.scalatest.{BeforeAndAfterEach, Matchers}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
import org.apache.spark.SparkFunSuite
@ -52,7 +53,7 @@ class UnsafeFixedWidthAggregationMapSuite
override def afterEach(): Unit = {
if (taskMemoryManager != null) {
val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
assert(leakedShuffleMemory === 0)
taskMemoryManager = null
@ -80,7 +81,7 @@ class UnsafeFixedWidthAggregationMapSuite
PAGE_SIZE_BYTES,
false // disable perf metrics
)
assert(!map.iterator().hasNext)
assert(!map.iterator().next())
map.free()
}
@ -100,13 +101,13 @@ class UnsafeFixedWidthAggregationMapSuite
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
assert(map.getAggregationBuffer(groupKey) != null)
val iter = map.iterator()
val entry = iter.next()
assert(!iter.hasNext)
entry.key.getString(0) should be ("cats")
entry.value.getInt(0) should be (0)
assert(iter.next())
iter.getKey.getString(0) should be ("cats")
iter.getValue.getInt(0) should be (0)
assert(!iter.next())
// Modifications to rows retrieved from the map should update the values in the map
entry.value.setInt(0, 42)
iter.getValue.setInt(0, 42)
map.getAggregationBuffer(groupKey).getInt(0) should be (42)
map.free()
@ -128,12 +129,14 @@ class UnsafeFixedWidthAggregationMapSuite
groupKeys.foreach { keyString =>
assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
}
val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
entry.key.getString(0)
}.toSet
seenKeys.size should be (groupKeys.size)
seenKeys should be (groupKeys)
val seenKeys = new mutable.HashSet[String]
val iter = map.iterator()
while (iter.next()) {
seenKeys += iter.getKey.getString(0)
}
assert(seenKeys.size === groupKeys.size)
assert(seenKeys === groupKeys)
map.free()
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.unsafe;
public abstract class KVIterator<K, V> {
public abstract boolean next();
public abstract K getKey();
public abstract V getValue();
public abstract void close();
}