[SPARK-18458][CORE] Fix signed integer overflow problem in an expression in RadixSort.java

## What changes were proposed in this pull request?

This PR prevents the result of an expression from becoming negative due to signed 32-bit integer overflow (e.g. `0x10?????? * 8 < 0`). It casts each operand to `long` before the calculation is executed. Since the arithmetic is then performed on, and the result interpreted as, a 64-bit `long`, the result of the expression stays positive.

## How was this patch tested?

Manually executed query 82 of TPC-DS with 100 TB of data.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15907 from kiszk/SPARK-18458.
This commit is contained in:
Kazuaki Ishizaki 2016-11-19 21:50:20 -08:00 committed by Reynold Xin
parent 856e004200
commit d93b655247
3 changed files with 40 additions and 38 deletions

View file

@ -17,6 +17,8 @@
package org.apache.spark.util.collection.unsafe.sort;
import com.google.common.primitives.Ints;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
@ -40,14 +42,14 @@ public class RadixSort {
* of always copying the data back to position zero for efficiency.
*/
public static int sort(
LongArray array, int numRecords, int startByteIndex, int endByteIndex,
LongArray array, long numRecords, int startByteIndex, int endByteIndex,
boolean desc, boolean signed) {
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 2 <= array.size();
int inIndex = 0;
int outIndex = numRecords;
long inIndex = 0;
long outIndex = numRecords;
if (numRecords > 0) {
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
@ -55,13 +57,13 @@ public class RadixSort {
sortAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
int tmp = inIndex;
long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
return inIndex;
return Ints.checkedCast(inIndex);
}
/**
@ -78,14 +80,14 @@ public class RadixSort {
* @param signed whether this is a signed (two's complement) sort (only applies to last byte).
*/
private static void sortAtByte(
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8;
long maxOffset = baseOffset + numRecords * 8;
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 8L;
for (long offset = baseOffset; offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
@ -106,13 +108,13 @@ public class RadixSort {
* significant byte. If the byte does not need sorting the array will be null.
*/
private static long[][] getCounts(
LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
// If all the byte values at a particular index are the same we don't need to count it.
long bitwiseMax = 0;
long bitwiseMin = -1L;
long maxOffset = array.getBaseOffset() + numRecords * 8;
long maxOffset = array.getBaseOffset() + numRecords * 8L;
Object baseObject = array.getBaseObject();
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
@ -146,18 +148,18 @@ public class RadixSort {
* @return the input counts array.
*/
private static long[] transformCountsToOffsets(
long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
boolean desc, boolean signed) {
assert counts.length == 256;
int start = signed ? 128 : 0; // output the negative records first (values 128-255).
if (desc) {
int pos = numRecords;
long pos = numRecords;
for (int i = start; i < start + 256; i++) {
pos -= counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
}
} else {
int pos = 0;
long pos = 0;
for (int i = start; i < start + 256; i++) {
long tmp = counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
@ -176,8 +178,8 @@ public class RadixSort {
*/
public static int sortKeyPrefixArray(
LongArray array,
int startIndex,
int numRecords,
long startIndex,
long numRecords,
int startByteIndex,
int endByteIndex,
boolean desc,
@ -186,8 +188,8 @@ public class RadixSort {
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
int inIndex = startIndex;
int outIndex = startIndex + numRecords * 2;
long inIndex = startIndex;
long outIndex = startIndex + numRecords * 2L;
if (numRecords > 0) {
long[][] counts = getKeyPrefixArrayCounts(
array, startIndex, numRecords, startByteIndex, endByteIndex);
@ -196,13 +198,13 @@ public class RadixSort {
sortKeyPrefixArrayAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
int tmp = inIndex;
long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
return inIndex;
return Ints.checkedCast(inIndex);
}
/**
@ -210,7 +212,7 @@ public class RadixSort {
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
@ -238,11 +240,11 @@ public class RadixSort {
* Specialization of sortAtByte() for key-prefix arrays.
*/
private static void sortKeyPrefixArrayAtByte(
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 16L;

View file

@ -322,7 +322,7 @@ public final class UnsafeInMemorySorter {
if (sortComparator != null) {
if (this.radixSortSupport != null) {
offset = RadixSort.sortKeyPrefixArray(
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(

View file

@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator}
import scala.util.Random
import com.google.common.primitives.Ints
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.unsafe.array.LongArray
@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter
import org.apache.spark.util.random.XORShiftRandom
class RadixSortSuite extends SparkFunSuite with Logging {
private val N = 10000 // scale this down for more readable results
private val N = 10000L // scale this down for more readable results
/**
* Describes a type of sort to test, e.g. two's complement descending. Each sort type has
@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
},
2, 4, false, false, true))
private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
val ref = Array.tabulate[Long](size) { i => rand }
val extended = ref ++ Array.fill[Long](size)(0)
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
(ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
}
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
val ref = Array.tabulate[Long](size * 2) { i => rand }
val extended = ref ++ Array.fill[Long](size * 2)(0)
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
(new LongArray(MemoryBlock.fromLongArray(ref)),
new LongArray(MemoryBlock.fromLongArray(extended)))
}
private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
var i = 0
val out = new Array[Long](length)
val out = new Array[Long](Ints.checkedCast(length))
while (i < length) {
out(i) = array.get(offset + i)
i += 1
@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging {
}
}
private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
r2: RecordPointerAndKeyPrefix): Int = {
refCmp.compare(r1.keyPrefix, r2.keyPrefix)
}
r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
})
}