[SPARK-31425][SQL][CORE] UnsafeKVExternalSorter/VariableLengthRowBasedKeyValueBatch should also respect UnsafeAlignedOffset
### What changes were proposed in this pull request? Make `UnsafeKVExternalSorter` / `VariableLengthRowBasedKeyValueBatch ` also respect `UnsafeAlignedOffset` when reading the record and update some out of date comemnts. ### Why are the changes needed? Since `BytesToBytesMap` respects `UnsafeAlignedOffset` when writing the record, `UnsafeKVExternalSorter` should also respect `UnsafeAlignedOffset` when reading the record from `BytesToBytesMap` otherwise it will causes data correctness issue. Unlike `UnsafeKVExternalSorter` may reading records from `BytesToBytesMap`, `VariableLengthRowBasedKeyValueBatch` writes and reads records by itself. Thus, similar to #22053 and [comment](https://github.com/apache/spark/pull/22053#issuecomment-411975239) there, fix for `VariableLengthRowBasedKeyValueBatch` more likely an improvement for the support of SPARC platform. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manually tested `HashAggregationQueryWithControlledFallbackSuite` with `UAO_SIZE=8` to simulate SPARC platform. And tests only pass with this fix. Closes #28195 from Ngone51/fix_uao. Authored-by: yi.wu <yi.wu@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b2e9e1717b
commit
40f9dbb628
|
@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {
|
|||
|
||||
private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;
|
||||
|
||||
private static int TEST_UAO_SIZE = 0;
|
||||
|
||||
// used for test only
|
||||
public static void setUaoSize(int size) {
|
||||
assert size == 0 || size == 4 || size == 8;
|
||||
TEST_UAO_SIZE = size;
|
||||
}
|
||||
|
||||
public static int getUaoSize() {
|
||||
return UAO_SIZE;
|
||||
return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
|
||||
}
|
||||
|
||||
public static int getSize(Object object, long offset) {
|
||||
switch (UAO_SIZE) {
|
||||
switch (getUaoSize()) {
|
||||
case 4:
|
||||
return Platform.getInt(object, offset);
|
||||
case 8:
|
||||
|
@ -46,7 +54,7 @@ public class UnsafeAlignedOffset {
|
|||
}
|
||||
|
||||
public static void putSize(Object object, long offset, int value) {
|
||||
switch (UAO_SIZE) {
|
||||
switch (getUaoSize()) {
|
||||
case 4:
|
||||
Platform.putInt(object, offset, value);
|
||||
break;
|
||||
|
|
|
@ -54,13 +54,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
|
|||
* probably be using sorting instead of hashing for better cache locality.
|
||||
*
|
||||
* The key and values under the hood are stored together, in the following format:
|
||||
* Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
|
||||
* Bytes 4 to 8: len(k)
|
||||
* Bytes 8 to 8 + len(k): key data
|
||||
* Bytes 8 + len(k) to 8 + len(k) + len(v): value data
|
||||
* Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
|
||||
* First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
|
||||
* Next uaoSize bytes: len(k)
|
||||
* Next len(k) bytes: key data
|
||||
* Next len(v) bytes: value data
|
||||
* Last 8 bytes: pointer to next pair
|
||||
*
|
||||
* This means that the first four bytes store the entire record (key + value) length. This format
|
||||
* It means first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
|
||||
* is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
|
||||
* so we can pass records from this map directly into the sorter to sort records in place.
|
||||
*/
|
||||
|
@ -706,7 +706,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
|
|||
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
|
||||
// the key address instead of storing the absolute address of the value, the key and value
|
||||
// must be stored in the same memory page.
|
||||
// (8 byte key length) (key) (value) (8 byte pointer to next value)
|
||||
// (total length) (key length) (key) (value) (8 byte pointer to next value)
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
final long recordLength = (2L * uaoSize) + klen + vlen + 8;
|
||||
if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
|
||||
|
|
|
@ -235,7 +235,7 @@ public final class UnsafeInMemorySorter {
|
|||
|
||||
/**
|
||||
* Inserts a record to be sorted. Assumes that the record pointer points to a record length
|
||||
* stored as a 4-byte integer, followed by the record's bytes.
|
||||
* stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
|
||||
*
|
||||
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
|
||||
* @param keyPrefix a user-defined key prefix
|
||||
|
|
|
@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions;
|
|||
import org.apache.spark.memory.TaskMemoryManager;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.UnsafeAlignedOffset;
|
||||
|
||||
/**
|
||||
* An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
|
||||
*
|
||||
* The format for each record looks like this:
|
||||
* The format for each record looks like this (in case of uaoSize = 4):
|
||||
* [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
|
||||
* [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
|
||||
* [8 bytes pointer to next]
|
||||
|
@ -41,7 +42,8 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
@Override
|
||||
public UnsafeRow appendRow(Object kbase, long koff, int klen,
|
||||
Object vbase, long voff, int vlen) {
|
||||
final long recordLength = 8L + klen + vlen + 8;
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
final long recordLength = 2 * uaoSize + klen + vlen + 8L;
|
||||
// if run out of max supported rows or page size, return null
|
||||
if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
|
||||
return null;
|
||||
|
@ -49,10 +51,10 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
|
||||
long offset = page.getBaseOffset() + pageCursor;
|
||||
final long recordOffset = offset;
|
||||
Platform.putInt(base, offset, klen + vlen + 4);
|
||||
Platform.putInt(base, offset + 4, klen);
|
||||
UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
|
||||
UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
|
||||
|
||||
offset += 8;
|
||||
offset += 2 * uaoSize;
|
||||
Platform.copyMemory(kbase, koff, base, offset, klen);
|
||||
offset += klen;
|
||||
Platform.copyMemory(vbase, voff, base, offset, vlen);
|
||||
|
@ -61,11 +63,11 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
|
||||
pageCursor += recordLength;
|
||||
|
||||
keyOffsets[numRows] = recordOffset + 8;
|
||||
keyOffsets[numRows] = recordOffset + 2 * uaoSize;
|
||||
|
||||
keyRowId = numRows;
|
||||
keyRow.pointTo(base, recordOffset + 8, klen);
|
||||
valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
|
||||
keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
|
||||
valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
|
||||
numRows++;
|
||||
return valueRow;
|
||||
}
|
||||
|
@ -79,7 +81,7 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
assert(rowId < numRows);
|
||||
if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
|
||||
long offset = keyOffsets[rowId];
|
||||
int klen = Platform.getInt(base, offset - 4);
|
||||
int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
|
||||
keyRow.pointTo(base, offset, klen);
|
||||
// set keyRowId so we can check if desired row is cached
|
||||
keyRowId = rowId;
|
||||
|
@ -99,9 +101,10 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
getKeyRow(rowId);
|
||||
}
|
||||
assert(rowId >= 0);
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
long offset = keyRow.getBaseOffset();
|
||||
int klen = keyRow.getSizeInBytes();
|
||||
int vlen = Platform.getInt(base, offset - 8) - klen - 4;
|
||||
int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
|
||||
valueRow.pointTo(base, offset + klen, vlen);
|
||||
return valueRow;
|
||||
}
|
||||
|
@ -141,14 +144,15 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
|
|||
return false;
|
||||
}
|
||||
|
||||
totalLength = Platform.getInt(base, offsetInPage) - 4;
|
||||
currentklen = Platform.getInt(base, offsetInPage + 4);
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
|
||||
currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
|
||||
currentvlen = totalLength - currentklen;
|
||||
|
||||
key.pointTo(base, offsetInPage + 8, currentklen);
|
||||
value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
|
||||
key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
|
||||
value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);
|
||||
|
||||
offsetInPage += 8 + totalLength + 8;
|
||||
offsetInPage += 2 * uaoSize + totalLength + 8;
|
||||
recordsInPage -= 1;
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType;
|
|||
import org.apache.spark.storage.BlockManager;
|
||||
import org.apache.spark.unsafe.KVIterator;
|
||||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.UnsafeAlignedOffset;
|
||||
import org.apache.spark.unsafe.array.LongArray;
|
||||
import org.apache.spark.unsafe.map.BytesToBytesMap;
|
||||
import org.apache.spark.unsafe.memory.MemoryBlock;
|
||||
|
@ -141,9 +142,10 @@ public final class UnsafeKVExternalSorter {
|
|||
|
||||
// Get encoded memory address
|
||||
// baseObject + baseOffset point to the beginning of the key data in the map, but that
|
||||
// the KV-pair's length data is stored in the word immediately before that address
|
||||
// the KV-pair's length data is stored at 2 * uaoSize bytes immediately before that address
|
||||
MemoryBlock page = loc.getMemoryPage();
|
||||
long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
|
||||
long address = taskMemoryManager.encodePageNumberAndOffset(page,
|
||||
baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());
|
||||
|
||||
// Compute prefix
|
||||
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
|
||||
|
@ -262,10 +264,11 @@ public final class UnsafeKVExternalSorter {
|
|||
Object baseObj2,
|
||||
long baseOff2,
|
||||
int baseLen2) {
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
// Note that since ordering doesn't need the total length of the record, we just pass 0
|
||||
// into the row.
|
||||
row1.pointTo(baseObj1, baseOff1 + 4, 0);
|
||||
row2.pointTo(baseObj2, baseOff2 + 4, 0);
|
||||
row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
|
||||
row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
|
||||
return ordering.compare(row1, row2);
|
||||
}
|
||||
}
|
||||
|
@ -289,11 +292,12 @@ public final class UnsafeKVExternalSorter {
|
|||
long recordOffset = underlying.getBaseOffset();
|
||||
int recordLen = underlying.getRecordLength();
|
||||
|
||||
// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
|
||||
// Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
|
||||
int uaoSize = UnsafeAlignedOffset.getUaoSize();
|
||||
int keyLen = Platform.getInt(baseObj, recordOffset);
|
||||
int valueLen = recordLen - keyLen - 4;
|
||||
key.pointTo(baseObj, recordOffset + 4, keyLen);
|
||||
value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
|
||||
int valueLen = recordLen - keyLen - uaoSize;
|
||||
key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
|
||||
value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);
|
||||
|
||||
return true;
|
||||
} else {
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
|
|||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.UnsafeAlignedOffset
|
||||
|
||||
|
||||
class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
|
||||
|
@ -1055,30 +1056,35 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
|
|||
Seq("true", "false").foreach { enableTwoLevelMaps =>
|
||||
withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
|
||||
enableTwoLevelMaps) {
|
||||
(1 to 3).foreach { fallbackStartsAt =>
|
||||
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
|
||||
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
|
||||
// Create a new df to make sure its physical operator picks up
|
||||
// spark.sql.TungstenAggregate.testFallbackStartsAt.
|
||||
// todo: remove it?
|
||||
val newActual = Dataset.ofRows(spark, actual.logicalPlan)
|
||||
Seq(4, 8).foreach { uaoSize =>
|
||||
UnsafeAlignedOffset.setUaoSize(uaoSize)
|
||||
(1 to 3).foreach { fallbackStartsAt =>
|
||||
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
|
||||
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
|
||||
// Create a new df to make sure its physical operator picks up
|
||||
// spark.sql.TungstenAggregate.testFallbackStartsAt.
|
||||
// todo: remove it?
|
||||
val newActual = Dataset.ofRows(spark, actual.logicalPlan)
|
||||
|
||||
QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
|
||||
case Some(errorMessage) =>
|
||||
val newErrorMessage =
|
||||
s"""
|
||||
|The following aggregation query failed when using HashAggregate with
|
||||
|controlled fallback (it falls back to bytes to bytes map once it has processed
|
||||
|${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
|
||||
|processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
|
||||
|
|
||||
|$errorMessage
|
||||
""".stripMargin
|
||||
QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
|
||||
case Some(errorMessage) =>
|
||||
val newErrorMessage =
|
||||
s"""
|
||||
|The following aggregation query failed when using HashAggregate with
|
||||
|controlled fallback (it falls back to bytes to bytes map once it has
|
||||
|processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
|
||||
|once it has processed $fallbackStartsAt input rows).
|
||||
|The query is ${actual.queryExecution}
|
||||
|$errorMessage
|
||||
""".stripMargin
|
||||
|
||||
fail(newErrorMessage)
|
||||
case None => // Success
|
||||
fail(newErrorMessage)
|
||||
case None => // Success
|
||||
}
|
||||
}
|
||||
}
|
||||
// reset static uaoSize to avoid affect other tests
|
||||
UnsafeAlignedOffset.setUaoSize(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue