diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
index 546e8780a6..d399e66aa2 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/UnsafeAlignedOffset.java
@@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {
 
   private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;
 
+  private static int TEST_UAO_SIZE = 0;
+
+  // used for tests only
+  public static void setUaoSize(int size) {
+    assert size == 0 || size == 4 || size == 8;
+    TEST_UAO_SIZE = size;
+  }
+
   public static int getUaoSize() {
-    return UAO_SIZE;
+    return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
   }
 
   public static int getSize(Object object, long offset) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
       case 4:
         return Platform.getInt(object, offset);
       case 8:
@@ -46,7 +54,7 @@ public class UnsafeAlignedOffset {
   }
 
   public static void putSize(Object object, long offset, int value) {
-    switch (UAO_SIZE) {
+    switch (getUaoSize()) {
       case 4:
         Platform.putInt(object, offset, value);
         break;
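With the test hook in place, both widths can be exercised from ordinary code. Below is a minimal round-trip sketch (the class name and the use of a plain byte array are illustrative, not part of the patch):

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;

public class UaoRoundTripExample {
  public static void main(String[] args) {
    // Force the 8-byte width, as a test would; 0 restores the platform default.
    UnsafeAlignedOffset.setUaoSize(8);
    try {
      byte[] page = new byte[16];
      long offset = Platform.BYTE_ARRAY_OFFSET;
      // Readers and writers go through the same uaoSize-aware helpers, so no
      // caller ever hard-codes a 4-byte or 8-byte length word.
      UnsafeAlignedOffset.putSize(page, offset, 42);
      assert UnsafeAlignedOffset.getSize(page, offset) == 42;
      assert UnsafeAlignedOffset.getUaoSize() == 8;
    } finally {
      UnsafeAlignedOffset.setUaoSize(0);
    }
  }
}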
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index a57cd3b3f3..64c240cea8 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -54,13 +54,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
  * probably be using sorting instead of hashing for better cache locality.
  *
  * The key and values under the hood are stored together, in the following format:
- * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
- * Bytes 4 to 8: len(k)
- * Bytes 8 to 8 + len(k): key data
- * Bytes 8 + len(k) to 8 + len(k) + len(v): value data
- * Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
+ * First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
+ * Next uaoSize bytes: len(k)
+ * Next len(k) bytes: key data
+ * Next len(v) bytes: value data
+ * Last 8 bytes: pointer to next pair
  *
- * This means that the first four bytes store the entire record (key + value) length. This format
+ * This means that the first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
  * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
  * so we can pass records from this map directly into the sorter to sort records in place.
  */
@@ -706,7 +706,7 @@
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (value) (8 byte pointer to next value)
+      // (total length) (key length) (key) (value) (8 byte pointer to next value)
       int uaoSize = UnsafeAlignedOffset.getUaoSize();
       final long recordLength = (2L * uaoSize) + klen + vlen + 8;
       if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
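To make the revised layout concrete, the arithmetic can be spelled out directly. The helper below is illustrative only, not code from the patch; with uaoSize = 4 it reproduces the old fixed offsets (key data at byte 8), and with uaoSize = 8 everything derived from the two length words shifts by 8, which is exactly what the 2L * uaoSize terms in append account for:

// Illustrative only: where each field of one BytesToBytesMap record starts,
// relative to the record base, per the layout comment above.
class RecordLayoutExample {
  static void describeRecord(int klen, int vlen, int uaoSize) {
    long totalLenField = 0;                        // len(k) + len(v) + uaoSize
    long keyLenField = uaoSize;                    // len(k)
    long keyData = 2L * uaoSize;                   // key bytes
    long valueData = 2L * uaoSize + klen;          // value bytes
    long nextPointer = 2L * uaoSize + klen + vlen; // 8-byte pointer to next pair
    long recordLength = nextPointer + 8;           // == (2L * uaoSize) + klen + vlen + 8
    System.out.printf("total@%d keyLen@%d key@%d value@%d next@%d length=%d%n",
      totalLenField, keyLenField, keyData, valueData, nextPointer, recordLength);
  }
}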
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index e14964d681..660eb790a5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -235,7 +235,7 @@ public final class UnsafeInMemorySorter {
 
   /**
    * Inserts a record to be sorted. Assumes that the record pointer points to a record length
-   * stored as a 4-byte integer, followed by the record's bytes.
+   * stored as an integer of uaoSize (4 or 8) bytes, followed by the record's bytes.
    *
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
index c823de4810..4ee913c9bf 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 
 /**
  * An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
  *
- * The format for each record looks like this:
+ * The format for each record looks like this (assuming uaoSize = 4):
  * [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
  * [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
  * [8 bytes pointer to next]
@@ -41,7 +42,8 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
   @Override
   public UnsafeRow appendRow(Object kbase, long koff, int klen,
                              Object vbase, long voff, int vlen) {
-    final long recordLength = 8L + klen + vlen + 8;
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
+    final long recordLength = 2 * uaoSize + klen + vlen + 8L;
     // if run out of max supported rows or page size, return null
     if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
       return null;
@@ -49,10 +51,10 @@
 
     long offset = page.getBaseOffset() + pageCursor;
     final long recordOffset = offset;
-    Platform.putInt(base, offset, klen + vlen + 4);
-    Platform.putInt(base, offset + 4, klen);
+    UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
+    UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);
-    offset += 8;
+    offset += 2 * uaoSize;
     Platform.copyMemory(kbase, koff, base, offset, klen);
     offset += klen;
     Platform.copyMemory(vbase, voff, base, offset, vlen);
@@ -61,11 +63,11 @@
 
     pageCursor += recordLength;
 
-    keyOffsets[numRows] = recordOffset + 8;
+    keyOffsets[numRows] = recordOffset + 2 * uaoSize;
     keyRowId = numRows;
-    keyRow.pointTo(base, recordOffset + 8, klen);
-    valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
+    keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
+    valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
     numRows++;
     return valueRow;
   }
@@ -79,7 +81,7 @@
     assert(rowId < numRows);
     if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
       long offset = keyOffsets[rowId];
-      int klen = Platform.getInt(base, offset - 4);
+      int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
       keyRow.pointTo(base, offset, klen);
       // set keyRowId so we can check if desired row is cached
       keyRowId = rowId;
@@ -99,9 +101,10 @@
       getKeyRow(rowId);
     }
     assert(rowId >= 0);
+    int uaoSize = UnsafeAlignedOffset.getUaoSize();
     long offset = keyRow.getBaseOffset();
     int klen = keyRow.getSizeInBytes();
-    int vlen = Platform.getInt(base, offset - 8) - klen - 4;
+    int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
     valueRow.pointTo(base, offset + klen, vlen);
     return valueRow;
   }
@@ -141,14 +144,15 @@
         return false;
       }
 
-      totalLength = Platform.getInt(base, offsetInPage) - 4;
-      currentklen = Platform.getInt(base, offsetInPage + 4);
+      int uaoSize = UnsafeAlignedOffset.getUaoSize();
+      totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
+      currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
       currentvlen = totalLength - currentklen;
 
-      key.pointTo(base, offsetInPage + 8, currentklen);
-      value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
+      key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
+      value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);
-      offsetInPage += 8 + totalLength + 8;
+      offsetInPage += 2 * uaoSize + totalLength + 8;
       recordsInPage -= 1;
       return true;
    }
  }
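All of the lookup paths above invert the same arithmetic: the key-length word sits uaoSize bytes before the key bytes, and the total-size word sits 2 * uaoSize bytes before them. A hypothetical decode helper (class and method names are illustrative; base would be the batch's page object) makes the relationship explicit:

import org.apache.spark.unsafe.UnsafeAlignedOffset;

// Illustrative decode of one record, mirroring getKeyRow/getValueFromKey above.
// keyOffset is where the key bytes start, i.e. recordOffset + 2 * uaoSize.
class BatchRecordDecodeExample {
  static long[] decode(Object base, long keyOffset) {
    int uaoSize = UnsafeAlignedOffset.getUaoSize();
    // len(k) sits uaoSize bytes before the key bytes.
    int klen = UnsafeAlignedOffset.getSize(base, keyOffset - uaoSize);
    // The total-size word (klen + vlen + uaoSize) sits 2 * uaoSize bytes before them.
    int vlen = UnsafeAlignedOffset.getSize(base, keyOffset - 2L * uaoSize) - klen - uaoSize;
    long valueOffset = keyOffset + klen;  // value bytes immediately follow the key
    return new long[] {klen, vlen, valueOffset};
  }
}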
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index acd54fe25d..7a9f61a2cc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -35,6 +35,7 @@ import org.apache.spark.sql.types.StructType;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.KVIterator;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.map.BytesToBytesMap;
 import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -141,9 +142,10 @@
       // Get encoded memory address
       // baseObject + baseOffset point to the beginning of the key data in the map, but that
-      // the KV-pair's length data is stored in the word immediately before that address
+      // the KV-pair's length data is stored in the 2 * uaoSize bytes immediately before that address
       MemoryBlock page = loc.getMemoryPage();
-      long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+      long address = taskMemoryManager.encodePageNumberAndOffset(page,
+        baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());
 
       // Compute prefix
       row.pointTo(baseObject, baseOffset, loc.getKeyLength());
 
@@ -262,10 +264,11 @@
         Object baseObj2,
         long baseOff2,
         int baseLen2) {
+      int uaoSize = UnsafeAlignedOffset.getUaoSize();
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
-      row1.pointTo(baseObj1, baseOff1 + 4, 0);
-      row2.pointTo(baseObj2, baseOff2 + 4, 0);
+      row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
+      row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
       return ordering.compare(row1, row2);
     }
   }
@@ -289,11 +292,12 @@
         long recordOffset = underlying.getBaseOffset();
         int recordLen = underlying.getRecordLength();
 
-        // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
-        int keyLen = Platform.getInt(baseObj, recordOffset);
-        int valueLen = recordLen - keyLen - 4;
-        key.pointTo(baseObj, recordOffset + 4, keyLen);
-        value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
+        // Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
+        int uaoSize = UnsafeAlignedOffset.getUaoSize();
+        int keyLen = UnsafeAlignedOffset.getSize(baseObj, recordOffset);
+        int valueLen = recordLen - keyLen - uaoSize;
+        key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
+        value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);
 
         return true;
       } else {
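Note that reading keyLen with Platform.getInt would be inconsistent with the uaoSize-wide length words this patch writes, so the read above goes through UnsafeAlignedOffset.getSize as well. Putting the hunks together: the pointer handed to the sorter targets the total-length word 2 * uaoSize bytes before the key data, the comparator skips one length word, and the iterator recovers valueLen from recordLen = keyLen + valueLen + uaoSize. The sketch below (illustrative, not part of the patch) checks that write and read agree under the 8-byte width:

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;

// Hypothetical consistency check: write the two length words the way the map
// and batch writers do, then read them back the way KVSorterIterator does.
public class KvLengthWordsExample {
  public static void main(String[] args) {
    UnsafeAlignedOffset.setUaoSize(8);
    try {
      int uaoSize = UnsafeAlignedOffset.getUaoSize();
      int klen = 16, vlen = 24;
      byte[] page = new byte[2 * uaoSize + klen + vlen + 8];
      long base = Platform.BYTE_ARRAY_OFFSET;
      // Writer side: total length word, then key length word.
      UnsafeAlignedOffset.putSize(page, base, klen + vlen + uaoSize);
      UnsafeAlignedOffset.putSize(page, base + uaoSize, klen);
      // Reader side: recordOffset points at the key-length word.
      long recordOffset = base + uaoSize;
      int recordLen = UnsafeAlignedOffset.getSize(page, base);
      int keyLen = UnsafeAlignedOffset.getSize(page, recordOffset);
      int valueLen = recordLen - keyLen - uaoSize;
      assert keyLen == klen && valueLen == vlen;
    } finally {
      UnsafeAlignedOffset.setUaoSize(0);
    }
  }
}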
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index f84b854048..ce40a65ed6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.UnsafeAlignedOffset
 
 class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
 
@@ -1055,30 +1056,35 @@
     Seq("true", "false").foreach { enableTwoLevelMaps =>
       withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> enableTwoLevelMaps) {
-        (1 to 3).foreach { fallbackStartsAt =>
-          withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
-            s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
-            // Create a new df to make sure its physical operator picks up
-            // spark.sql.TungstenAggregate.testFallbackStartsAt.
-            // todo: remove it?
-            val newActual = Dataset.ofRows(spark, actual.logicalPlan)
+        Seq(4, 8).foreach { uaoSize =>
+          UnsafeAlignedOffset.setUaoSize(uaoSize)
+          (1 to 3).foreach { fallbackStartsAt =>
+            withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
+              s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
+              // Create a new df to make sure its physical operator picks up
+              // spark.sql.TungstenAggregate.testFallbackStartsAt.
+              // todo: remove it?
+              val newActual = Dataset.ofRows(spark, actual.logicalPlan)
 
-            QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
-              case Some(errorMessage) =>
-                val newErrorMessage =
-                  s"""
-                     |The following aggregation query failed when using HashAggregate with
-                     |controlled fallback (it falls back to bytes to bytes map once it has processed
-                     |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
-                     |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
-                     |
-                     |$errorMessage
-                  """.stripMargin
+              QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
+                case Some(errorMessage) =>
+                  val newErrorMessage =
+                    s"""
+                       |The following aggregation query failed when using HashAggregate with
+                       |controlled fallback (it falls back to bytes to bytes map once it has
+                       |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
+                       |once it has processed $fallbackStartsAt input rows).
+                       |The query is ${actual.queryExecution}
+                       |$errorMessage
+                     """.stripMargin
 
-                fail(newErrorMessage)
-              case None => // Success
-            }
-          }
-        }
+                  fail(newErrorMessage)
+                case None => // Success
+              }
+            }
+          }
+          // reset the static uaoSize to avoid affecting other tests
+          UnsafeAlignedOffset.setUaoSize(0)
+        }
       }
     }
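One caveat on the test change: fail(newErrorMessage) throws, so on a failing query the trailing UnsafeAlignedOffset.setUaoSize(0) is skipped and the forced width leaks into later tests. A try/finally version of the same pattern guarantees the reset; below is a minimal Java sketch, where runAggregationChecks is a hypothetical stand-in for the suite's checks:

// Exception-safe variant of the reset pattern: the finally block restores the
// platform default even when an assertion failure is thrown mid-loop.
for (int uaoSize : new int[] {4, 8}) {
  UnsafeAlignedOffset.setUaoSize(uaoSize);
  try {
    runAggregationChecks();  // hypothetical stand-in for the suite's checks
  } finally {
    UnsafeAlignedOffset.setUaoSize(0);  // 0 restores the Platform.unaligned() default
  }
}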