[SPARK-18458][CORE] Fix signed integer overflow problem at an expression in RadixSort.java
## What changes were proposed in this pull request?
This PR prevents the result of an expression from becoming negative due to signed integer overflow (e.g. 0x10?????? * 8 < 0). It casts each operand to `long` before the calculation is performed. Because the arithmetic is then carried out in `long`, the result of the expression stays positive.
## How was this patch tested?
Manually executed query82 of TPC-DS with 100TB.
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Closes #15907 from kiszk/SPARK-18458.
This commit is contained in:
parent
856e004200
commit
d93b655247
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.util.collection.unsafe.sort;
|
||||
|
||||
import com.google.common.primitives.Ints;
|
||||
|
||||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.array.LongArray;
|
||||
|
||||
|
@ -40,14 +42,14 @@ public class RadixSort {
|
|||
* of always copying the data back to position zero for efficiency.
|
||||
*/
|
||||
public static int sort(
|
||||
LongArray array, int numRecords, int startByteIndex, int endByteIndex,
|
||||
LongArray array, long numRecords, int startByteIndex, int endByteIndex,
|
||||
boolean desc, boolean signed) {
|
||||
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
|
||||
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
|
||||
assert endByteIndex > startByteIndex;
|
||||
assert numRecords * 2 <= array.size();
|
||||
int inIndex = 0;
|
||||
int outIndex = numRecords;
|
||||
long inIndex = 0;
|
||||
long outIndex = numRecords;
|
||||
if (numRecords > 0) {
|
||||
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
|
||||
for (int i = startByteIndex; i <= endByteIndex; i++) {
|
||||
|
@ -55,13 +57,13 @@ public class RadixSort {
|
|||
sortAtByte(
|
||||
array, numRecords, counts[i], i, inIndex, outIndex,
|
||||
desc, signed && i == endByteIndex);
|
||||
int tmp = inIndex;
|
||||
long tmp = inIndex;
|
||||
inIndex = outIndex;
|
||||
outIndex = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return inIndex;
|
||||
return Ints.checkedCast(inIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -78,14 +80,14 @@ public class RadixSort {
|
|||
* @param signed whether this is a signed (two's complement) sort (only applies to last byte).
|
||||
*/
|
||||
private static void sortAtByte(
|
||||
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
|
||||
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
|
||||
boolean desc, boolean signed) {
|
||||
assert counts.length == 256;
|
||||
long[] offsets = transformCountsToOffsets(
|
||||
counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
|
||||
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
|
||||
Object baseObject = array.getBaseObject();
|
||||
long baseOffset = array.getBaseOffset() + inIndex * 8;
|
||||
long maxOffset = baseOffset + numRecords * 8;
|
||||
long baseOffset = array.getBaseOffset() + inIndex * 8L;
|
||||
long maxOffset = baseOffset + numRecords * 8L;
|
||||
for (long offset = baseOffset; offset < maxOffset; offset += 8) {
|
||||
long value = Platform.getLong(baseObject, offset);
|
||||
int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
|
||||
|
@ -106,13 +108,13 @@ public class RadixSort {
|
|||
* significant byte. If the byte does not need sorting the array will be null.
|
||||
*/
|
||||
private static long[][] getCounts(
|
||||
LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
|
||||
LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
|
||||
long[][] counts = new long[8][];
|
||||
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
|
||||
// If all the byte values at a particular index are the same we don't need to count it.
|
||||
long bitwiseMax = 0;
|
||||
long bitwiseMin = -1L;
|
||||
long maxOffset = array.getBaseOffset() + numRecords * 8;
|
||||
long maxOffset = array.getBaseOffset() + numRecords * 8L;
|
||||
Object baseObject = array.getBaseObject();
|
||||
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
|
||||
long value = Platform.getLong(baseObject, offset);
|
||||
|
@ -146,18 +148,18 @@ public class RadixSort {
|
|||
* @return the input counts array.
|
||||
*/
|
||||
private static long[] transformCountsToOffsets(
|
||||
long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
|
||||
long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
|
||||
boolean desc, boolean signed) {
|
||||
assert counts.length == 256;
|
||||
int start = signed ? 128 : 0; // output the negative records first (values 129-255).
|
||||
if (desc) {
|
||||
int pos = numRecords;
|
||||
long pos = numRecords;
|
||||
for (int i = start; i < start + 256; i++) {
|
||||
pos -= counts[i & 0xff];
|
||||
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
|
||||
}
|
||||
} else {
|
||||
int pos = 0;
|
||||
long pos = 0;
|
||||
for (int i = start; i < start + 256; i++) {
|
||||
long tmp = counts[i & 0xff];
|
||||
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
|
||||
|
@ -176,8 +178,8 @@ public class RadixSort {
|
|||
*/
|
||||
public static int sortKeyPrefixArray(
|
||||
LongArray array,
|
||||
int startIndex,
|
||||
int numRecords,
|
||||
long startIndex,
|
||||
long numRecords,
|
||||
int startByteIndex,
|
||||
int endByteIndex,
|
||||
boolean desc,
|
||||
|
@ -186,8 +188,8 @@ public class RadixSort {
|
|||
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
|
||||
assert endByteIndex > startByteIndex;
|
||||
assert numRecords * 4 <= array.size();
|
||||
int inIndex = startIndex;
|
||||
int outIndex = startIndex + numRecords * 2;
|
||||
long inIndex = startIndex;
|
||||
long outIndex = startIndex + numRecords * 2L;
|
||||
if (numRecords > 0) {
|
||||
long[][] counts = getKeyPrefixArrayCounts(
|
||||
array, startIndex, numRecords, startByteIndex, endByteIndex);
|
||||
|
@ -196,13 +198,13 @@ public class RadixSort {
|
|||
sortKeyPrefixArrayAtByte(
|
||||
array, numRecords, counts[i], i, inIndex, outIndex,
|
||||
desc, signed && i == endByteIndex);
|
||||
int tmp = inIndex;
|
||||
long tmp = inIndex;
|
||||
inIndex = outIndex;
|
||||
outIndex = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
return inIndex;
|
||||
return Ints.checkedCast(inIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -210,7 +212,7 @@ public class RadixSort {
|
|||
* getCounts with some added parameters but that seems to hurt in benchmarks.
|
||||
*/
|
||||
private static long[][] getKeyPrefixArrayCounts(
|
||||
LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
|
||||
LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
|
||||
long[][] counts = new long[8][];
|
||||
long bitwiseMax = 0;
|
||||
long bitwiseMin = -1L;
|
||||
|
@ -238,11 +240,11 @@ public class RadixSort {
|
|||
* Specialization of sortAtByte() for key-prefix arrays.
|
||||
*/
|
||||
private static void sortKeyPrefixArrayAtByte(
|
||||
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
|
||||
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
|
||||
boolean desc, boolean signed) {
|
||||
assert counts.length == 256;
|
||||
long[] offsets = transformCountsToOffsets(
|
||||
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
|
||||
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
|
||||
Object baseObject = array.getBaseObject();
|
||||
long baseOffset = array.getBaseOffset() + inIndex * 8L;
|
||||
long maxOffset = baseOffset + numRecords * 16L;
|
||||
|
|
|
@ -322,7 +322,7 @@ public final class UnsafeInMemorySorter {
|
|||
if (sortComparator != null) {
|
||||
if (this.radixSortSupport != null) {
|
||||
offset = RadixSort.sortKeyPrefixArray(
|
||||
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
|
||||
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
|
||||
radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
|
||||
} else {
|
||||
MemoryBlock unused = new MemoryBlock(
|
||||
|
|
|
@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator}
|
|||
|
||||
import scala.util.Random
|
||||
|
||||
import com.google.common.primitives.Ints
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.unsafe.array.LongArray
|
||||
|
@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter
|
|||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
||||
class RadixSortSuite extends SparkFunSuite with Logging {
|
||||
private val N = 10000 // scale this down for more readable results
|
||||
private val N = 10000L // scale this down for more readable results
|
||||
|
||||
/**
|
||||
* Describes a type of sort to test, e.g. two's complement descending. Each sort type has
|
||||
|
@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
|
|||
},
|
||||
2, 4, false, false, true))
|
||||
|
||||
private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
|
||||
val ref = Array.tabulate[Long](size) { i => rand }
|
||||
val extended = ref ++ Array.fill[Long](size)(0)
|
||||
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
|
||||
val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
|
||||
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
|
||||
(ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
|
||||
}
|
||||
|
||||
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
|
||||
val ref = Array.tabulate[Long](size * 2) { i => rand }
|
||||
val extended = ref ++ Array.fill[Long](size * 2)(0)
|
||||
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
|
||||
val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
|
||||
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
|
||||
(new LongArray(MemoryBlock.fromLongArray(ref)),
|
||||
new LongArray(MemoryBlock.fromLongArray(extended)))
|
||||
}
|
||||
|
||||
private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
|
||||
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
|
||||
var i = 0
|
||||
val out = new Array[Long](length)
|
||||
val out = new Array[Long](Ints.checkedCast(length))
|
||||
while (i < length) {
|
||||
out(i) = array.get(offset + i)
|
||||
i += 1
|
||||
|
@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
|
||||
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
|
||||
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
|
||||
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
|
||||
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
|
||||
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
|
||||
override def compare(
|
||||
r1: RecordPointerAndKeyPrefix,
|
||||
r2: RecordPointerAndKeyPrefix): Int = {
|
||||
refCmp.compare(r1.keyPrefix, r2.keyPrefix)
|
||||
}
|
||||
r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue