[SPARK-31425][SQL][CORE] UnsafeKVExternalSorter/VariableLengthRowBasedKeyValueBatch should also respect UnsafeAlignedOffset

### What changes were proposed in this pull request?

Make `UnsafeKVExternalSorter` / `VariableLengthRowBasedKeyValueBatch` also respect `UnsafeAlignedOffset` when reading records, and update some out-of-date comments.

### Why are the changes needed?

Since `BytesToBytesMap` respects `UnsafeAlignedOffset` when writing records, `UnsafeKVExternalSorter` should also respect `UnsafeAlignedOffset` when reading records from `BytesToBytesMap`; otherwise it will cause data correctness issues.

Unlike `UnsafeKVExternalSorter`, which may read records written by `BytesToBytesMap`, `VariableLengthRowBasedKeyValueBatch` writes and reads records by itself. Thus, similar to #22053 and the [comment](https://github.com/apache/spark/pull/22053#issuecomment-411975239) there, the fix for `VariableLengthRowBasedKeyValueBatch` is more of an improvement for SPARC platform support.
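
As a self-contained illustration of the underlying mismatch (a sketch for this description only, using `java.nio.ByteBuffer` in place of Spark's `Platform`/`UnsafeAlignedOffset` helpers): when the writer stores a length word of `uaoSize = 8` bytes but the reader assumes a fixed 4-byte word, the reader decodes the wrong value on a big-endian platform such as SPARC.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class UaoMismatchSketch {
  public static void main(String[] args) {
    // Writer side: the length word is uaoSize = 8 bytes wide (SPARC-like platform).
    ByteBuffer page = ByteBuffer.allocate(16).order(ByteOrder.BIG_ENDIAN);
    page.putLong(0, 42L);        // record length stored as an 8-byte word

    // Correct reader: reads the same 8-byte word back.
    long ok = page.getLong(0);   // 42

    // Buggy reader: assumes the length is always a 4-byte int.
    int broken = page.getInt(0); // 0 on big-endian -- the high half of the word

    System.out.println("correct=" + ok + ", broken=" + broken);
  }
}
```

Besides decoding the wrong value, a reader that assumes 4-byte length words also computes wrong offsets for the key and value data, since the record header is wider than it expects.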

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually tested `HashAggregationQueryWithControlledFallbackSuite` with `UAO_SIZE=8` to simulate the SPARC platform. The tests only pass with this fix.

Closes #28195 from Ngone51/fix_uao.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Committed 2020-04-17 04:48:27 +00:00
commit 40f9dbb628 (parent b2e9e1717b)
6 changed files with 76 additions and 54 deletions

UnsafeAlignedOffset.java

@@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {
 
   private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;
 
+  private static int TEST_UAO_SIZE = 0;
+
+  // used for test only
+  public static void setUaoSize(int size) {
+    assert size == 0 || size == 4 || size == 8;
+    TEST_UAO_SIZE = size;
+  }
+
   public static int getUaoSize() {
-    return UAO_SIZE;
+    return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
   }
 
   public static int getSize(Object object, long offset) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
      case 4:
        return Platform.getInt(object, offset);
      case 8:

@@ -46,7 +54,7 @@ public class UnsafeAlignedOffset {
   }
 
   public static void putSize(Object object, long offset, int value) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
      case 4:
        Platform.putInt(object, offset, value);
        break;

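For reference, a rough sketch (not part of this commit) of how the new `setUaoSize` test hook can be exercised; it assumes `Platform.allocateMemory`/`Platform.freeMemory` from `org.apache.spark.unsafe.Platform` for off-heap scratch space:

```java
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;

public class UaoSizeRoundTrip {
  public static void main(String[] args) {
    for (int simulated : new int[] {4, 8}) {
      // 4 simulates platforms with unaligned access (e.g. x86), 8 simulates SPARC.
      UnsafeAlignedOffset.setUaoSize(simulated);
      int uaoSize = UnsafeAlignedOffset.getUaoSize();

      long addr = Platform.allocateMemory(2L * uaoSize);
      // Writer and reader go through the same putSize/getSize pair, so the
      // length words round-trip regardless of the simulated word size.
      UnsafeAlignedOffset.putSize(null, addr, 123);
      UnsafeAlignedOffset.putSize(null, addr + uaoSize, 456);
      System.out.println("uao=" + uaoSize
          + " first=" + UnsafeAlignedOffset.getSize(null, addr)
          + " second=" + UnsafeAlignedOffset.getSize(null, addr + uaoSize));
      Platform.freeMemory(addr);
    }
    // Reset so later code sees the real platform value again.
    UnsafeAlignedOffset.setUaoSize(0);
  }
}
```
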
BytesToBytesMap.java

@@ -54,13 +54,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
  * probably be using sorting instead of hashing for better cache locality.
  *
  * The key and values under the hood are stored together, in the following format:
- *   Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
- *   Bytes 4 to 8: len(k)
- *   Bytes 8 to 8 + len(k): key data
- *   Bytes 8 + len(k) to 8 + len(k) + len(v): value data
- *   Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
+ *   First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
+ *   Next uaoSize bytes: len(k)
+ *   Next len(k) bytes: key data
+ *   Next len(v) bytes: value data
+ *   Last 8 bytes: pointer to next pair
  *
- * This means that the first four bytes store the entire record (key + value) length. This format
+ * It means first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
  * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */

@@ -706,7 +706,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (value) (8 byte pointer to next value)
+      // (total length) (key length) (key) (value) (8 byte pointer to next value)
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       final long recordLength = (2L * uaoSize) + klen + vlen + 8;
       if (currentPage == null || currentPage.size() - pageCursor < recordLength) {

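As a worked example of the layout described in the updated comment (illustrative only; `klen = 16` and `vlen = 24` are hypothetical values):

```java
public class RecordLayoutExample {
  public static void main(String[] args) {
    int klen = 16;   // hypothetical key length
    int vlen = 24;   // hypothetical value length
    for (int uaoSize : new int[] {4, 8}) {
      long keyOffset     = 2L * uaoSize;        // key data starts after the two length words
      long valueOffset   = keyOffset + klen;    // value data follows the key
      long nextPtrOffset = valueOffset + vlen;  // 8-byte pointer to the next pair
      long recordLength  = 2L * uaoSize + klen + vlen + 8;
      System.out.printf(
          "uaoSize=%d key@%d value@%d nextPtr@%d recordLength=%d%n",
          uaoSize, keyOffset, valueOffset, nextPtrOffset, recordLength);
    }
  }
}
```
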
UnsafeInMemorySorter.java

@@ -235,7 +235,7 @@ public final class UnsafeInMemorySorter {
   /**
    * Inserts a record to be sorted. Assumes that the record pointer points to a record length
-   * stored as a 4-byte integer, followed by the record's bytes.
+   * stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
    *
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix

VariableLengthRowBasedKeyValueBatch.java

@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 
 /**
  * An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
  *
- * The format for each record looks like this:
+ * The format for each record looks like this (in case of uaoSize = 4):
  * [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
  * [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
  * [8 bytes pointer to next]

@@ -41,7 +42,8 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
   @Override
   public UnsafeRow appendRow(Object kbase, long koff, int klen,
                              Object vbase, long voff, int vlen) {
-    final long recordLength = 8L + klen + vlen + 8;
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
+    final long recordLength = 2 * uaoSize + klen + vlen + 8L;
     // if run out of max supported rows or page size, return null
     if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
       return null;

@@ -49,10 +51,10 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
 
     long offset = page.getBaseOffset() + pageCursor;
     final long recordOffset = offset;
-    Platform.putInt(base, offset, klen + vlen + 4);
-    Platform.putInt(base, offset + 4, klen);
+    UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
+    UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
 
-    offset += 8;
+    offset += 2 * uaoSize;
     Platform.copyMemory(kbase, koff, base, offset, klen);
     offset += klen;
     Platform.copyMemory(vbase, voff, base, offset, vlen);

@@ -61,11 +63,11 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
 
     pageCursor += recordLength;
 
-    keyOffsets[numRows] = recordOffset + 8;
+    keyOffsets[numRows] = recordOffset + 2 * uaoSize;
 
     keyRowId = numRows;
-    keyRow.pointTo(base, recordOffset + 8, klen);
-    valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
+    keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
+    valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
     numRows++;
     return valueRow;
   }

@@ -79,7 +81,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
     assert(rowId < numRows);
     if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
       long offset = keyOffsets[rowId];
-      int klen = Platform.getInt(base, offset - 4);
+      int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
       keyRow.pointTo(base, offset, klen);
       // set keyRowId so we can check if desired row is cached
       keyRowId = rowId;

@@ -99,9 +101,10 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
       getKeyRow(rowId);
     }
     assert(rowId >= 0);
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
     long offset = keyRow.getBaseOffset();
     int klen = keyRow.getSizeInBytes();
-    int vlen = Platform.getInt(base, offset - 8) - klen - 4;
+    int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
     valueRow.pointTo(base, offset + klen, vlen);
     return valueRow;
   }

@@ -141,14 +144,15 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatch {
        return false;
      }
 
-      totalLength = Platform.getInt(base, offsetInPage) - 4;
-      currentklen = Platform.getInt(base, offsetInPage + 4);
+      int uaoSize = UnsafeAlignedOffset.getUaoSize();
+      totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
+      currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
       currentvlen = totalLength - currentklen;
 
-      key.pointTo(base, offsetInPage + 8, currentklen);
-      value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
+      key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
+      value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);
 
-      offsetInPage += 8 + totalLength + 8;
+      offsetInPage += 2 * uaoSize + totalLength + 8;
       recordsInPage -= 1;
       return true;
     }

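To make the reverse arithmetic in `getKeyRow`/`getValueFromKey` concrete, a small worked example with hypothetical sizes (not code from this patch):

```java
public class KeyValueOffsetArithmetic {
  public static void main(String[] args) {
    // Hypothetical record: klen = 16, vlen = 24, simulated SPARC word size.
    int uaoSize = 8, klen = 16, vlen = 24;
    long recordOffset = 0;

    // appendRow stores keyOffsets[rowId] = recordOffset + 2 * uaoSize.
    long keyOffset = recordOffset + 2L * uaoSize;      // 16

    // getKeyRow: klen sits one UAO word before the key data.
    long klenAddr = keyOffset - uaoSize;               // recordOffset + 8

    // getValueFromKey: the total length sits two UAO words before the key data.
    long totalLenAddr = keyOffset - 2L * uaoSize;      // recordOffset + 0
    int storedTotalLen = klen + vlen + uaoSize;        // what appendRow wrote: 48
    int derivedVlen = storedTotalLen - klen - uaoSize; // 24, matches vlen

    System.out.println("klenAddr=" + klenAddr
        + " totalLenAddr=" + totalLenAddr + " derivedVlen=" + derivedVlen);
  }
}
```
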
UnsafeKVExternalSorter.java

@@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;

@@ -141,9 +142,10 @@ public final class UnsafeKVExternalSorter {
         // Get encoded memory address
         // baseObject + baseOffset point to the beginning of the key data in the map, but that
-        // the KV-pair's length data is stored in the word immediately before that address
+        // the KV-pair's length data is stored at 2 * uaoSize bytes immediately before that address
 
         MemoryBlock page = loc.getMemoryPage();
-        long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+        long address = taskMemoryManager.encodePageNumberAndOffset(page,
+          baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());
 
         // Compute prefix
         row.pointTo(baseObject, baseOffset, loc.getKeyLength());

@@ -262,10 +264,11 @@ public final class UnsafeKVExternalSorter {
         Object baseObj2,
         long baseOff2,
         int baseLen2) {
+      int uaoSize = UnsafeAlignedOffset.getUaoSize();
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
-      row1.pointTo(baseObj1, baseOff1 + 4, 0);
-      row2.pointTo(baseObj2, baseOff2 + 4, 0);
+      row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
+      row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
       return ordering.compare(row1, row2);
     }
   }

@@ -289,11 +292,12 @@ public final class UnsafeKVExternalSorter {
         long recordOffset = underlying.getBaseOffset();
         int recordLen = underlying.getRecordLength();
 
-        // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+        // Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
+        int uaoSize = UnsafeAlignedOffset.getUaoSize();
         int keyLen = Platform.getInt(baseObj, recordOffset);
-        int valueLen = recordLen - keyLen - 4;
-        key.pointTo(baseObj, recordOffset + 4, keyLen);
-        value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
+        int valueLen = recordLen - keyLen - uaoSize;
+        key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
+        value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);
 
         return true;
       } else {

AggregationQuerySuite.scala

@@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UnsafeAlignedOffset
 
 class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
 

@@ -1055,30 +1056,35 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
     Seq("true", "false").foreach { enableTwoLevelMaps =>
       withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
         enableTwoLevelMaps) {
-        (1 to 3).foreach { fallbackStartsAt =>
-          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
-            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
-            // Create a new df to make sure its physical operator picks up
-            // spark.sql.TungstenAggregate.testFallbackStartsAt.
-            // todo: remove it?
-            val newActual = Dataset.ofRows(spark, actual.logicalPlan)
+        Seq(4, 8).foreach { uaoSize =>
+          UnsafeAlignedOffset.setUaoSize(uaoSize)
+          (1 to 3).foreach { fallbackStartsAt =>
+            withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+              s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+              // Create a new df to make sure its physical operator picks up
+              // spark.sql.TungstenAggregate.testFallbackStartsAt.
+              // todo: remove it?
+              val newActual = Dataset.ofRows(spark, actual.logicalPlan)
 
-            QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
-              case Some(errorMessage) =>
-                val newErrorMessage =
-                  s"""
-                    |The following aggregation query failed when using HashAggregate with
-                    |controlled fallback (it falls back to bytes to bytes map once it has processed
-                    |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
-                    |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
-                    |
-                    |$errorMessage
-                  """.stripMargin
+              QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
+                case Some(errorMessage) =>
+                  val newErrorMessage =
+                    s"""
+                       |The following aggregation query failed when using HashAggregate with
+                       |controlled fallback (it falls back to bytes to bytes map once it has
+                       |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
+                       |once it has processed $fallbackStartsAt input rows).
+                       |The query is ${actual.queryExecution}
+                       |$errorMessage
+                    """.stripMargin
 
-                fail(newErrorMessage)
-              case None => // Success
+                  fail(newErrorMessage)
+                case None => // Success
+              }
             }
           }
+          // reset static uaoSize to avoid affect other tests
+          UnsafeAlignedOffset.setUaoSize(0)
         }
       }
     }