[SPARK-15962][SQL] Introduce implementation with a dense format for UnsafeArrayData
## What changes were proposed in this pull request?
This PR introduces a more compact representation for ```UnsafeArrayData```.
```UnsafeArrayData``` must be able to hold a ```null``` value in any entry of an array. In the current version, it has three parts:
```
[numElements] [offsets] [values]
```
`offsets` holds `numElements` entries, one per element, and a negative entry marks the element as `null`. This region increases the memory footprint and adds an indirection to every access to `values`.
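To make the indirection concrete, here is a minimal sketch of the old layout for an int array. It is not Spark code: `OldArrayLayoutSketch` is a hypothetical class, and `java.nio.ByteBuffer` stands in for Spark's `Platform` memory accessors.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical illustration of the old layout: [numElements][offsets][values].
final class OldArrayLayoutSketch {

    // Writes an int array in the old layout; a null element gets a negative offset.
    static ByteBuffer write(Integer[] values) {
        int n = values.length;
        // 4 bytes for numElements, 4 bytes per offset, at most 4 bytes per value.
        ByteBuffer buf = ByteBuffer.allocate(4 + 4 * n + 4 * n).order(ByteOrder.LITTLE_ENDIAN);
        buf.putInt(0, n);
        int valueOffset = 4 + 4 * n; // first byte of the values region
        for (int i = 0; i < n; i++) {
            if (values[i] == null) {
                buf.putInt(4 + 4 * i, -valueOffset); // negative offset marks null
            } else {
                buf.putInt(4 + 4 * i, valueOffset);
                buf.putInt(valueOffset, values[i]);
                valueOffset += 4;
            }
        }
        return buf;
    }

    // Reads one element: one load for the offset slot, a second load for the value.
    static Integer read(ByteBuffer buf, int ordinal) {
        int offset = buf.getInt(4 + 4 * ordinal);
        if (offset < 0) {
            return null;
        }
        return buf.getInt(offset);
    }

    public static void main(String[] args) {
        ByteBuffer buf = write(new Integer[] {7, null, 9});
        System.out.println(read(buf, 0) + " " + read(buf, 1) + " " + read(buf, 2));
    }
}
```

Every read costs two dependent loads, and every element pays 4 extra bytes for its offset slot, which is the overhead this PR removes.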
This PR uses a bit vector to represent the nullability of each element, as `UnsafeRow` does, and eliminates the indirection when accessing each element. The new ```UnsafeArrayData``` has four parts:
```
[numElements][null bits][values or offset&length][variable length portion]
```
In the `null bits` region, we store 1 bit per element that indicates whether the element is null. Its total size is ceil(numElements / 8) bytes, rounded up to an 8-byte boundary.
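A minimal sketch of the header arithmetic and the null-bit test follows. The `headerBytes` formula mirrors `calculateHeaderPortionInBytes` from this patch; the class name `NullBitsSketch` and the plain `long[]` bit set are assumptions made for illustration, standing in for `BitSetMethods` operating on raw memory.

```java
// Hypothetical illustration of the [numElements][null bits] header.
final class NullBitsSketch {

    // 8 bytes for numElements plus one 64-bit word per 64 elements, rounded up.
    static int headerBytes(int numElements) {
        return 8 + ((numElements + 63) / 64) * 8;
    }

    // Marks the element at `ordinal` as null in a word-based bit set.
    static void setNull(long[] words, int ordinal) {
        words[ordinal >> 6] |= 1L << (ordinal & 63);
    }

    // Tests whether the element at `ordinal` is marked null.
    static boolean isNull(long[] words, int ordinal) {
        return (words[ordinal >> 6] & (1L << (ordinal & 63))) != 0;
    }

    public static void main(String[] args) {
        long[] words = new long[2];
        setNull(words, 70);
        System.out.println(headerBytes(65) + " " + isNull(words, 70) + " " + isNull(words, 69));
    }
}
```

Because the header size depends only on `numElements`, the start of the values region can be computed once and element `i` of a fixed-width type sits at a fixed stride from it.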
In the `values or offset&length` region, we store the content of the elements. For fixed-length primitive types, such as long, double, or int, the value is stored directly in the field. For non-primitive or variable-length values, the field stores a relative offset (w.r.t. the base address of the array) that points to the beginning of the variable-length data, together with its length (the two are combined into a single long). The region as a whole is word-aligned. In the `variable length portion`, each element is aligned to an 8-byte boundary.
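The combined offset-and-length field can be sketched as below. The shifts match what the getters in this patch do (`offsetAndSize >> 32` for the offset, a cast to `int` for the size); `OffsetAndSizeSketch` is a hypothetical class, and masking the size is a defensive choice of this sketch rather than something the patch does.

```java
// Hypothetical illustration of packing a relative offset and a size into one long.
final class OffsetAndSizeSketch {

    // Upper 32 bits hold the offset from the array's base address, lower 32 bits the size.
    static long pack(int relativeOffset, int size) {
        return ((long) relativeOffset << 32) | (size & 0xFFFFFFFFL);
    }

    static int offset(long offsetAndSize) {
        return (int) (offsetAndSize >> 32);
    }

    static int size(long offsetAndSize) {
        return (int) offsetAndSize;
    }

    public static void main(String[] args) {
        long packed = pack(24, 5);
        System.out.println(offset(packed) + " " + size(packed));
    }
}
```

Storing both halves in one 8-byte slot keeps every element slot the same width, so a variable-length element is found with one load instead of reading two neighbouring offsets.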
The new format can reduce the memory footprint and improve the performance of accessing each element. An example of a memory footprint comparison:
1024x1024-element integer array
Size of ```baseObject``` for ```UnsafeArrayData``` before this PR: 4 + 4x1024x1024 (offsets) + 4x1024x1024 (values) = about 8M bytes
Size of ```baseObject``` for ```UnsafeArrayData``` with this PR: 8 + 1024x1024/8 (null bits) + 4x1024x1024 (values) = about 4.125M bytes
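The footprint of each layout can also be computed in bytes with a small helper. This is a sketch under two assumptions taken from the code in this patch: the old layout spends 4 bytes on `numElements` plus a 4-byte offset per element, and the new layout spends an 8-byte count plus one null bit per element rounded up to 64-bit words, with the values region rounded up to 8 bytes. `FootprintSketch` and its method names are hypothetical.

```java
// Hypothetical helper comparing the byte footprint of the two layouts.
final class FootprintSketch {

    // Old layout: [numElements: 4 bytes][offsets: 4 bytes each][values].
    static long oldFormatBytes(long numElements, int elementSize) {
        return 4 + 4 * numElements + elementSize * numElements;
    }

    // New layout: [numElements: 8 bytes][null bits in 64-bit words][values, 8-byte aligned].
    static long newFormatBytes(long numElements, int elementSize) {
        long header = 8 + ((numElements + 63) / 64) * 8;
        long values = ((elementSize * numElements + 7) / 8) * 8;
        return header + values;
    }

    public static void main(String[] args) {
        long n = 1024L * 1024L;
        System.out.println("old: " + oldFormatBytes(n, 4) + " bytes");
        System.out.println("new: " + newFormatBytes(n, 4) + " bytes");
    }
}
```

For 4-byte elements the per-element cost drops from 8 bytes to just over 4, so the saving approaches half for large arrays; for 8-byte elements it drops from 12 bytes to just over 8.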
In summary, we got 1.0-2.6x performance improvements over the code before applying this PR.
Here are performance results of [benchmark programs](04d2e4b6db/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala):
**Read UnsafeArrayData**: 1.7x and 1.6x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Without SPARK-15962
Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 430 / 436 390.0 2.6 1.0X
Double 456 / 485 367.8 2.7 0.9X
With SPARK-15962
Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 252 / 260 666.1 1.5 1.0X
Double 281 / 292 597.7 1.7 0.9X
````
**Write UnsafeArrayData**: 1.0x and 1.1x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Without SPARK-15962
Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 203 / 273 103.4 9.7 1.0X
Double 239 / 356 87.9 11.4 0.8X
With SPARK-15962
Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 196 / 249 107.0 9.3 1.0X
Double 227 / 367 92.3 10.8 0.9X
````
**Get primitive array from UnsafeArrayData**: 2.6x and 1.6x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Without SPARK-15962
Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 207 / 217 304.2 3.3 1.0X
Double 257 / 363 245.2 4.1 0.8X
With SPARK-15962
Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 151 / 198 415.8 2.4 1.0X
Double 214 / 394 293.6 3.4 0.7X
````
**Create UnsafeArrayData from primitive array**: 1.7x and 2.1x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Without SPARK-15962
Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 340 / 385 185.1 5.4 1.0X
Double 479 / 705 131.3 7.6 0.7X
With SPARK-15962
Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 206 / 211 306.0 3.3 1.0X
Double 232 / 406 271.6 3.7 0.9X
````
1.7x and 1.4x performance improvements in [```UDTSerializationBenchmark```](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala) over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Without SPARK-15962
VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
serialize 442 / 533 0.0 441927.1 1.0X
deserialize 217 / 274 0.0 217087.6 2.0X
With SPARK-15962
VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
serialize 265 / 318 0.0 265138.5 1.0X
deserialize 155 / 197 0.0 154611.4 1.7X
````
## How was this patch tested?
Added unit tests to ```UnsafeArraySuite```
Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
Closes #13680 from kiszk/SPARK-15962.
parent 6ee28423ad
commit 85b0a15754

@@ -29,6 +29,8 @@ public final class Platform {

   private static final Unsafe _UNSAFE;

+  public static final int BOOLEAN_ARRAY_OFFSET;
+
   public static final int BYTE_ARRAY_OFFSET;

   public static final int SHORT_ARRAY_OFFSET;
@@ -235,6 +237,7 @@ public final class Platform {
     _UNSAFE = unsafe;

     if (_UNSAFE != null) {
+      BOOLEAN_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(boolean[].class);
       BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
       SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
       INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
@@ -242,6 +245,7 @@ public final class Platform {
       FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
       DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
     } else {
+      BOOLEAN_ARRAY_OFFSET = 0;
       BYTE_ARRAY_OFFSET = 0;
       SHORT_ARRAY_OFFSET = 0;
       INT_ARRAY_OFFSET = 0;

@@ -57,13 +57,12 @@ object UDTSerializationBenchmark {
     }

     /*
-    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
-    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
-
-    VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-    -------------------------------------------------------------------------------------------
-    serialize 380 / 392 0.0 379730.0 1.0X
-    deserialize 138 / 142 0.0 137816.6 2.8X
+    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
+    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+    ------------------------------------------------------------------------------------------------
+    serialize 265 / 318 0.0 265138.5 1.0X
+    deserialize 155 / 197 0.0 154611.4 1.7X
     */
     benchmark.run()
   }

@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -32,23 +33,31 @@ import org.apache.spark.unsafe.types.UTF8String;
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
- * Each tuple has three parts: [numElements] [offsets] [values]
+ * Each array has four parts:
+ * [numElements][null bits][values or offset&length][variable length portion]
  *
- * The `numElements` is 4 bytes storing the number of elements of this array.
+ * The `numElements` is 8 bytes storing the number of elements of this array.
  *
- * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the
- * base address of the array) of this element in `values` region. We can get the length of this
- * element by subtracting next offset.
- * Note that offset can by negative which means this element is null.
+ * In the `null bits` region, we store 1 bit per element, represents whether an element is null
+ * Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries.
  *
- * In the `values` region, we store the content of elements. As we can get length info, so elements
- * can be variable-length.
+ * In the `values or offset&length` region, we store the content of elements. For fields that hold
+ * fixed-length primitive types, such as long, double, or int, we store the value directly
+ * in the field. The whole fixed-length portion (even for byte) is aligned to 8-byte boundaries.
+ * For fields with non-primitive or variable-length values, we store a relative offset
+ * (w.r.t. the base address of the array) that points to the beginning of the variable-length field
+ * and length (they are combined into a long). For variable length portion, each is aligned
+ * to 8-byte boundaries.
  *
  * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
  */
-// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
+
 public final class UnsafeArrayData extends ArrayData {

+  public static int calculateHeaderPortionInBytes(int numFields) {
+    return 8 + ((numFields + 63)/ 64) * 8;
+  }
+
   private Object baseObject;
   private long baseOffset;

@@ -56,25 +65,20 @@ public final class UnsafeArrayData extends ArrayData {
   private int numElements;

   // The size of this array's backing data, in bytes.
-  // The 4-bytes header of `numElements` is also included.
+  // The 8-bytes header of `numElements` is also included.
   private int sizeInBytes;

+  /** The position to start storing array elements, */
+  private long elementOffset;
+
+  private long getElementOffset(int ordinal, int elementSize) {
+    return elementOffset + ordinal * elementSize;
+  }
+
   public Object getBaseObject() { return baseObject; }
   public long getBaseOffset() { return baseOffset; }
   public int getSizeInBytes() { return sizeInBytes; }

-  private int getElementOffset(int ordinal) {
-    return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
-  }
-
-  private int getElementSize(int offset, int ordinal) {
-    if (ordinal == numElements - 1) {
-      return sizeInBytes - offset;
-    } else {
-      return Math.abs(getElementOffset(ordinal + 1)) - offset;
-    }
-  }
-
   private void assertIndexIsValid(int ordinal) {
     assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
     assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
@@ -102,20 +106,22 @@ public final class UnsafeArrayData extends ArrayData {
    * @param sizeInBytes the size of this array's backing data, in bytes
    */
   public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
-    // Read the number of elements from the first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
+    // Read the number of elements from the first 8 bytes.
+    final long numElements = Platform.getLong(baseObject, baseOffset);
     assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+    assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE";

-    this.numElements = numElements;
+    this.numElements = (int)numElements;
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.sizeInBytes = sizeInBytes;
+    this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
   }

   @Override
   public boolean isNullAt(int ordinal) {
     assertIndexIsValid(ordinal);
-    return getElementOffset(ordinal) < 0;
+    return BitSetMethods.isSet(baseObject, baseOffset + 8, ordinal);
   }

   @Override
@@ -165,68 +171,50 @@ public final class UnsafeArrayData extends ArrayData {
   @Override
   public boolean getBoolean(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return false;
-    return Platform.getBoolean(baseObject, baseOffset + offset);
+    return Platform.getBoolean(baseObject, getElementOffset(ordinal, 1));
   }

   @Override
   public byte getByte(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getByte(baseObject, baseOffset + offset);
+    return Platform.getByte(baseObject, getElementOffset(ordinal, 1));
   }

   @Override
   public short getShort(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getShort(baseObject, baseOffset + offset);
+    return Platform.getShort(baseObject, getElementOffset(ordinal, 2));
   }

   @Override
   public int getInt(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getInt(baseObject, baseOffset + offset);
+    return Platform.getInt(baseObject, getElementOffset(ordinal, 4));
   }

   @Override
   public long getLong(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getLong(baseObject, baseOffset + offset);
+    return Platform.getLong(baseObject, getElementOffset(ordinal, 8));
   }

   @Override
   public float getFloat(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getFloat(baseObject, baseOffset + offset);
+    return Platform.getFloat(baseObject, getElementOffset(ordinal, 4));
   }

   @Override
   public double getDouble(int ordinal) {
     assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return 0;
-    return Platform.getDouble(baseObject, baseOffset + offset);
+    return Platform.getDouble(baseObject, getElementOffset(ordinal, 8));
   }

   @Override
   public Decimal getDecimal(int ordinal, int precision, int scale) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-
+    if (isNullAt(ordinal)) return null;
     if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      final long value = Platform.getLong(baseObject, baseOffset + offset);
-      return Decimal.apply(value, precision, scale);
+      return Decimal.apply(getLong(ordinal), precision, scale);
     } else {
       final byte[] bytes = getBinary(ordinal);
       final BigInteger bigInteger = new BigInteger(bytes);
@@ -237,19 +225,19 @@ public final class UnsafeArrayData extends ArrayData {

   @Override
   public UTF8String getUTF8String(int ordinal) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-    final int size = getElementSize(offset, ordinal);
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
     return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
   }

   @Override
   public byte[] getBinary(int ordinal) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-    final int size = getElementSize(offset, ordinal);
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
     final byte[] bytes = new byte[size];
     Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size);
     return bytes;
@@ -257,9 +245,9 @@ public final class UnsafeArrayData extends ArrayData {

   @Override
   public CalendarInterval getInterval(int ordinal) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
     final int months = (int) Platform.getLong(baseObject, baseOffset + offset);
     final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8);
     return new CalendarInterval(months, microseconds);
@@ -267,10 +255,10 @@ public final class UnsafeArrayData extends ArrayData {

   @Override
   public UnsafeRow getStruct(int ordinal, int numFields) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-    final int size = getElementSize(offset, ordinal);
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
     final UnsafeRow row = new UnsafeRow(numFields);
     row.pointTo(baseObject, baseOffset + offset, size);
     return row;
@@ -278,10 +266,10 @@ public final class UnsafeArrayData extends ArrayData {

   @Override
   public UnsafeArrayData getArray(int ordinal) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-    final int size = getElementSize(offset, ordinal);
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
     final UnsafeArrayData array = new UnsafeArrayData();
     array.pointTo(baseObject, baseOffset + offset, size);
     return array;
@@ -289,10 +277,10 @@ public final class UnsafeArrayData extends ArrayData {

   @Override
   public UnsafeMapData getMap(int ordinal) {
-    assertIndexIsValid(ordinal);
-    final int offset = getElementOffset(ordinal);
-    if (offset < 0) return null;
-    final int size = getElementSize(offset, ordinal);
+    if (isNullAt(ordinal)) return null;
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
     final UnsafeMapData map = new UnsafeMapData();
     map.pointTo(baseObject, baseOffset + offset, size);
     return map;
@@ -341,63 +329,108 @@ public final class UnsafeArrayData extends ArrayData {
     return arrayCopy;
   }

-  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
-    if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
+  @Override
+  public boolean[] toBooleanArray() {
+    boolean[] values = new boolean[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.BOOLEAN_ARRAY_OFFSET, numElements);
+    return values;
+  }
+
+  @Override
+  public byte[] toByteArray() {
+    byte[] values = new byte[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.BYTE_ARRAY_OFFSET, numElements);
+    return values;
+  }
+
+  @Override
+  public short[] toShortArray() {
+    short[] values = new short[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2);
+    return values;
+  }
+
+  @Override
+  public int[] toIntArray() {
+    int[] values = new int[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4);
+    return values;
+  }
+
+  @Override
+  public long[] toLongArray() {
+    long[] values = new long[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8);
+    return values;
+  }
+
+  @Override
+  public float[] toFloatArray() {
+    float[] values = new float[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4);
+    return values;
+  }
+
+  @Override
+  public double[] toDoubleArray() {
+    double[] values = new double[numElements];
+    Platform.copyMemory(
+      baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8);
+    return values;
+  }
+
+  private static UnsafeArrayData fromPrimitiveArray(
+      Object arr, int offset, int length, int elementSize) {
+    final long headerInBytes = calculateHeaderPortionInBytes(length);
+    final long valueRegionInBytes = elementSize * length;
+    final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
+    if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
       throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
         "it's too big.");
     }

-    final int offsetRegionSize = 4 * arr.length;
-    final int valueRegionSize = 4 * arr.length;
-    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
-    final byte[] data = new byte[totalSize];
+    final long[] data = new long[(int)totalSizeInLongs];

-    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
-
-    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
-    int valueOffset = 4 + offsetRegionSize;
-    for (int i = 0; i < arr.length; i++) {
-      Platform.putInt(data, offsetPosition, valueOffset);
-      offsetPosition += 4;
-      valueOffset += 4;
-    }
-
-    Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
-      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+    Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
+    Platform.copyMemory(arr, offset, data,
+      Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);

     UnsafeArrayData result = new UnsafeArrayData();
-    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+    result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
     return result;
   }

+  public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
+    return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(byte[] arr) {
+    return fromPrimitiveArray(arr, Platform.BYTE_ARRAY_OFFSET, arr.length, 1);
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(short[] arr) {
+    return fromPrimitiveArray(arr, Platform.SHORT_ARRAY_OFFSET, arr.length, 2);
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+    return fromPrimitiveArray(arr, Platform.INT_ARRAY_OFFSET, arr.length, 4);
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(long[] arr) {
+    return fromPrimitiveArray(arr, Platform.LONG_ARRAY_OFFSET, arr.length, 8);
+  }
+
+  public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
+    return fromPrimitiveArray(arr, Platform.FLOAT_ARRAY_OFFSET, arr.length, 4);
+  }
+
   public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
-    if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
-      throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
-        "it's too big.");
-    }
-
-    final int offsetRegionSize = 4 * arr.length;
-    final int valueRegionSize = 8 * arr.length;
-    final int totalSize = 4 + offsetRegionSize + valueRegionSize;
-    final byte[] data = new byte[totalSize];
-
-    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
-
-    int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
-    int valueOffset = 4 + offsetRegionSize;
-    for (int i = 0; i < arr.length; i++) {
-      Platform.putInt(data, offsetPosition, valueOffset);
-      offsetPosition += 4;
-      valueOffset += 8;
-    }
-
-    Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
-      Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
-
-    UnsafeArrayData result = new UnsafeArrayData();
-    result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
-    return result;
+    return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
   }
-
-  // TODO: add more specialized methods.
 }

@@ -25,7 +25,7 @@ import org.apache.spark.unsafe.Platform;
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head
  * to indicate the number of bytes of the unsafe key array.
  * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
@@ -65,14 +65,15 @@ public final class UnsafeMapData extends MapData {
    * @param sizeInBytes the size of this map's backing data, in bytes
    */
   public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
-    // Read the numBytes of key array from the first 4 bytes.
-    final int keyArraySize = Platform.getInt(baseObject, baseOffset);
-    final int valueArraySize = sizeInBytes - keyArraySize - 4;
+    // Read the numBytes of key array from the first 8 bytes.
+    final long keyArraySize = Platform.getLong(baseObject, baseOffset);
     assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+    assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE";
+    final int valueArraySize = sizeInBytes - (int)keyArraySize - 8;
     assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";

-    keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
-    values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+    keys.pointTo(baseObject, baseOffset + 8, (int)keyArraySize);
+    values.pointTo(baseObject, baseOffset + 8 + keyArraySize, valueArraySize);

     assert keys.numElements() == values.numElements();


@@ -19,9 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen;

 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;

+import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
+
 /**
  * A helper class to write data into global row buffer using `UnsafeArrayData` format,
  * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
|
@ -33,134 +37,213 @@ public class UnsafeArrayWriter {
|
|||
// The offset of the global buffer where we start to write this array.
|
||||
private int startingOffset;
|
||||
|
||||
public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
|
||||
// We need 4 bytes to store numElements and 4 bytes each element to store offset.
|
||||
final int fixedSize = 4 + 4 * numElements;
|
||||
// The number of elements in this array
|
||||
private int numElements;
|
||||
|
||||
private int headerInBytes;
|
||||
|
||||
private void assertIndexIsValid(int index) {
|
||||
assert index >= 0 : "index (" + index + ") should >= 0";
|
||||
assert index < numElements : "index (" + index + ") should < " + numElements;
|
||||
}
|
||||
|
||||
public void initialize(BufferHolder holder, int numElements, int elementSize) {
|
||||
// We need 8 bytes to store numElements in header
|
||||
this.numElements = numElements;
|
||||
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
|
||||
|
||||
this.holder = holder;
|
||||
this.startingOffset = holder.cursor;
|
||||
|
||||
holder.grow(fixedSize);
|
||||
Platform.putInt(holder.buffer, holder.cursor, numElements);
|
||||
holder.cursor += fixedSize;
|
||||
// Grows the global buffer ahead for header and fixed size data.
|
||||
int fixedPartInBytes =
|
||||
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
|
||||
holder.grow(headerInBytes + fixedPartInBytes);
|
||||
|
||||
// Grows the global buffer ahead for fixed size data.
|
||||
holder.grow(fixedElementSize * numElements);
|
||||
// Write numElements and clear out null bits to header
|
||||
Platform.putLong(holder.buffer, startingOffset, numElements);
|
||||
for (int i = 8; i < headerInBytes; i += 8) {
|
||||
Platform.putLong(holder.buffer, startingOffset + i, 0L);
|
||||
}
|
||||
|
||||
// fill 0 into reminder part of 8-bytes alignment in unsafe array
|
||||
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
|
||||
Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
|
||||
}
|
||||
holder.cursor += (headerInBytes + fixedPartInBytes);
|
||||
}
|
||||
|
||||
private long getElementOffset(int ordinal) {
|
||||
return startingOffset + 4 + 4 * ordinal;
|
||||
private void zeroOutPaddingBytes(int numBytes) {
|
||||
if ((numBytes & 0x07) > 0) {
|
||||
Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
|
||||
}
|
||||
}
|
||||
|
||||
public void setNullAt(int ordinal) {
|
||||
final int relativeOffset = holder.cursor - startingOffset;
|
||||
// Writes negative offset value to represent null element.
|
||||
Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
|
||||
private long getElementOffset(int ordinal, int elementSize) {
|
||||
return startingOffset + headerInBytes + ordinal * elementSize;
|
||||
}
|
||||
|
||||
public void setOffset(int ordinal) {
|
||||
final int relativeOffset = holder.cursor - startingOffset;
|
||||
Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
|
||||
public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
|
||||
assertIndexIsValid(ordinal);
|
||||
final long relativeOffset = currentCursor - startingOffset;
|
||||
final long offsetAndSize = (relativeOffset << 32) | (long)size;
|
||||
|
||||
write(ordinal, offsetAndSize);
|
||||
}
|
||||
|
||||
private void setNullBit(int ordinal) {
  assertIndexIsValid(ordinal);
  BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
}

public void setNullBoolean(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
}

public void setNullByte(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
}

public void setNullShort(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
}

public void setNullInt(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0);
}

public void setNullLong(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
}

public void setNullFloat(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
}

public void setNullDouble(int ordinal) {
  setNullBit(ordinal);
  // put zero into the corresponding field when set null
  Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
}

public void setNull(int ordinal) { setNullLong(ordinal); }
|
||||
|
||||
public void write(int ordinal, boolean value) {
|
||||
Platform.putBoolean(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 1;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, byte value) {
|
||||
Platform.putByte(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 1;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, short value) {
|
||||
Platform.putShort(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 2;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, int value) {
|
||||
Platform.putInt(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 4;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, long value) {
|
||||
Platform.putLong(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 8;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, float value) {
|
||||
if (Float.isNaN(value)) {
|
||||
value = Float.NaN;
|
||||
}
|
||||
Platform.putFloat(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 4;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, double value) {
|
||||
if (Double.isNaN(value)) {
|
||||
value = Double.NaN;
|
||||
}
|
||||
Platform.putDouble(holder.buffer, holder.cursor, value);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 8;
|
||||
assertIndexIsValid(ordinal);
|
||||
Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
|
||||
}
|
||||
|
||||
public void write(int ordinal, Decimal input, int precision, int scale) {
|
||||
// make sure Decimal object has the same scale as DecimalType
|
||||
assertIndexIsValid(ordinal);
|
||||
if (input.changePrecision(precision, scale)) {
|
||||
if (precision <= Decimal.MAX_LONG_DIGITS()) {
|
||||
Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
|
||||
setOffset(ordinal);
|
||||
holder.cursor += 8;
|
||||
write(ordinal, input.toUnscaledLong());
|
||||
} else {
|
||||
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
|
||||
assert bytes.length <= 16;
|
||||
holder.grow(bytes.length);
|
||||
final int numBytes = bytes.length;
|
||||
assert numBytes <= 16;
|
||||
int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
|
||||
holder.grow(roundedSize);
|
||||
|
||||
zeroOutPaddingBytes(numBytes);
|
||||
|
||||
// Write the bytes to the variable length portion.
|
||||
Platform.copyMemory(
|
||||
bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
|
||||
setOffset(ordinal);
|
||||
holder.cursor += bytes.length;
|
||||
bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
|
||||
setOffsetAndSize(ordinal, holder.cursor, numBytes);
|
||||
|
||||
// move the cursor forward, keeping it on an 8-byte boundary
|
||||
holder.cursor += roundedSize;
|
||||
}
|
||||
} else {
|
||||
setNullAt(ordinal);
|
||||
setNull(ordinal);
|
||||
}
|
||||
}
|
||||
|
||||
public void write(int ordinal, UTF8String input) {
|
||||
final int numBytes = input.numBytes();
|
||||
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
|
||||
|
||||
// grow the global buffer before writing data.
|
||||
holder.grow(numBytes);
|
||||
holder.grow(roundedSize);
|
||||
|
||||
zeroOutPaddingBytes(numBytes);
|
||||
|
||||
// Write the bytes to the variable length portion.
|
||||
input.writeToMemory(holder.buffer, holder.cursor);
|
||||
|
||||
setOffset(ordinal);
|
||||
setOffsetAndSize(ordinal, holder.cursor, numBytes);
|
||||
|
||||
// move the cursor forward.
|
||||
holder.cursor += numBytes;
|
||||
holder.cursor += roundedSize;
|
||||
}
|
||||
|
||||
public void write(int ordinal, byte[] input) {
|
||||
final int numBytes = input.length;
|
||||
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
|
||||
|
||||
// grow the global buffer before writing data.
|
||||
holder.grow(input.length);
|
||||
holder.grow(roundedSize);
|
||||
|
||||
zeroOutPaddingBytes(numBytes);
|
||||
|
||||
// Write the bytes to the variable length portion.
|
||||
Platform.copyMemory(
|
||||
input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
|
||||
input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
|
||||
|
||||
setOffset(ordinal);
|
||||
setOffsetAndSize(ordinal, holder.cursor, numBytes);
|
||||
|
||||
// move the cursor forward.
|
||||
holder.cursor += input.length;
|
||||
holder.cursor += roundedSize;
|
||||
}
|
||||
|
||||
public void write(int ordinal, CalendarInterval input) {
|
||||
|
@ -171,7 +254,7 @@ public class UnsafeArrayWriter {
|
|||
Platform.putLong(holder.buffer, holder.cursor, input.months);
|
||||
Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
|
||||
|
||||
setOffset(ordinal);
|
||||
setOffsetAndSize(ordinal, holder.cursor, 16);
|
||||
|
||||
// move the cursor forward.
|
||||
holder.cursor += 16;
|
||||
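The writer changes above address each element in two ways: a null bit stored in the header (starting 8 bytes after the array start) and, for variable-length values, a single 8-byte word that combines the relative offset and the size. The following is a minimal Java sketch of both ideas on plain `long` values. The class and method names (`ElementWordSketch`, `pack`, `offsetOf`, `sizeOf`, `setNull`, `isNull`) are hypothetical and not part of Spark, and the bitmap here is a `long[]` rather than the buffer the real writer uses.

```java
final class ElementWordSketch {
  // Pack a relative offset (upper 32 bits) and a size (lower 32 bits) into one long,
  // following `(relativeOffset << 32) | (long) size` in setOffsetAndSize.
  // The mask on `size` is an addition of this sketch so a negative size cannot
  // overwrite the offset bits through sign extension.
  static long pack(int relativeOffset, int size) {
    return (((long) relativeOffset) << 32) | (size & 0xFFFFFFFFL);
  }

  static int offsetOf(long offsetAndSize) {
    return (int) (offsetAndSize >>> 32);
  }

  static int sizeOf(long offsetAndSize) {
    return (int) offsetAndSize;
  }

  // One bit per element, 64 elements per 8-byte word.
  static void setNull(long[] nullBits, int ordinal) {
    nullBits[ordinal >> 6] |= 1L << (ordinal & 63);
  }

  static boolean isNull(long[] nullBits, int ordinal) {
    return (nullBits[ordinal >> 6] & (1L << (ordinal & 63))) != 0;
  }

  public static void main(String[] args) {
    long word = pack(24, 5);
    System.out.println(offsetOf(word) + " " + sizeOf(word)); // prints "24 5"

    long[] bits = new long[2]; // room for 128 elements
    setNull(bits, 70);
    System.out.println(isNull(bits, 70) + " " + isNull(bits, 69)); // prints "true false"
  }
}
```

Because the offset and size live in one fixed-width slot, every element sits at a position computable from its ordinal alone, which is what removes the extra indirection of the old offset table.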
|
|
|
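The `write(int, float)` and `write(int, double)` overloads in the writer diff replace any NaN input with `Float.NaN` or `Double.NaN` before storing it. A likely reason (an assumption, since the diff does not state it) is that NaN has many bit patterns, and storing one canonical pattern keeps byte-wise comparison and hashing of the buffer consistent. A minimal sketch of that canonicalization, with a hypothetical class name `NaNCanonicalSketch`:

```java
final class NaNCanonicalSketch {
  // Map every NaN bit pattern to the single canonical Float.NaN.
  static float canonical(float value) {
    return Float.isNaN(value) ? Float.NaN : value;
  }

  // Same idea for doubles.
  static double canonical(double value) {
    return Double.isNaN(value) ? Double.NaN : value;
  }

  public static void main(String[] args) {
    // A NaN whose payload bits differ from the canonical NaN.
    float oddNaN = Float.intBitsToFloat(0x7fc00001);
    int stored = Float.floatToRawIntBits(canonical(oddNaN));
    System.out.println(Integer.toHexString(stored));
  }
}
```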
@ -124,7 +124,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
|
||||
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
|
||||
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
|
||||
"""
|
||||
|
||||
case m @ MapType(kt, vt, _) =>
|
||||
|
@ -134,7 +133,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
|
||||
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
|
||||
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
|
||||
"""
|
||||
|
||||
case t: DecimalType =>
|
||||
|
@ -189,29 +187,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
|
||||
val jt = ctx.javaType(et)
|
||||
|
||||
val fixedElementSize = et match {
|
||||
val elementOrOffsetSize = et match {
|
||||
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
|
||||
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
|
||||
case _ => 0
|
||||
case _ => 8 // we need 8 bytes to store offset and length
|
||||
}
|
||||
|
||||
val tmpCursor = ctx.freshName("tmpCursor")
|
||||
val writeElement = et match {
|
||||
case t: StructType =>
|
||||
s"""
|
||||
$arrayWriter.setOffset($index);
|
||||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
|
||||
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
|
||||
"""
|
||||
|
||||
case a @ ArrayType(et, _) =>
|
||||
s"""
|
||||
$arrayWriter.setOffset($index);
|
||||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
|
||||
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
|
||||
"""
|
||||
|
||||
case m @ MapType(kt, vt, _) =>
|
||||
s"""
|
||||
$arrayWriter.setOffset($index);
|
||||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
|
||||
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
|
||||
"""
|
||||
|
||||
case t: DecimalType =>
|
||||
|
@ -222,16 +224,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
case _ => s"$arrayWriter.write($index, $element);"
|
||||
}
|
||||
|
||||
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
|
||||
s"""
|
||||
if ($input instanceof UnsafeArrayData) {
|
||||
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
|
||||
} else {
|
||||
final int $numElements = $input.numElements();
|
||||
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
|
||||
$arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
|
||||
|
||||
for (int $index = 0; $index < $numElements; $index++) {
|
||||
if ($input.isNullAt($index)) {
|
||||
$arrayWriter.setNullAt($index);
|
||||
$arrayWriter.setNull$primitiveTypeName($index);
|
||||
} else {
|
||||
final $jt $element = ${ctx.getValue(input, et, index)};
|
||||
$writeElement
|
||||
|
@ -261,16 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
final ArrayData $keys = $input.keyArray();
|
||||
final ArrayData $values = $input.valueArray();
|
||||
|
||||
// preserve 4 bytes to write the key array numBytes later.
|
||||
$bufferHolder.grow(4);
|
||||
$bufferHolder.cursor += 4;
|
||||
// preserve 8 bytes to write the key array numBytes later.
|
||||
$bufferHolder.grow(8);
|
||||
$bufferHolder.cursor += 8;
|
||||
|
||||
// Remember the current cursor so that we can write numBytes of key array later.
|
||||
final int $tmpCursor = $bufferHolder.cursor;
|
||||
|
||||
${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
|
||||
// Write the numBytes of key array into the first 4 bytes.
|
||||
Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
|
||||
// Write the numBytes of key array into the first 8 bytes.
|
||||
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
|
||||
|
||||
${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
|
||||
}
|
||||
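The generated map-writing code above reserves 8 bytes, writes the key array, backfills those 8 bytes with the key array's size in bytes, and then writes the value array. A minimal sketch of that reserve-then-backfill pattern follows. It uses `java.nio.ByteBuffer` in place of Spark's `BufferHolder` and `Platform`, and the class `MapLayoutSketch` and its methods are hypothetical.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class MapLayoutSketch {
  // Lay out [8-byte key-array size][key array bytes][value array bytes].
  static ByteBuffer write(byte[] keyArrayBytes, byte[] valueArrayBytes) {
    ByteBuffer buf = ByteBuffer
        .allocate(8 + keyArrayBytes.length + valueArrayBytes.length)
        .order(ByteOrder.nativeOrder());
    buf.position(8);                  // reserve 8 bytes for the key-array size
    int tmpCursor = buf.position();   // remember where the key array starts
    buf.put(keyArrayBytes);
    // Backfill the reserved slot now that the key array's size is known.
    buf.putLong(tmpCursor - 8, buf.position() - tmpCursor);
    buf.put(valueArrayBytes);
    buf.flip();
    return buf;
  }

  static long keyArraySize(ByteBuffer map) {
    return map.getLong(0);
  }

  // The value array starts right after the size slot and the key array.
  static int valueArrayOffset(ByteBuffer map) {
    return 8 + (int) map.getLong(0);
  }

  public static void main(String[] args) {
    ByteBuffer map = write(new byte[32], new byte[40]);
    System.out.println(keyArraySize(map) + " " + valueArrayOffset(map) + " " + map.remaining());
  }
}
```

Widening the size slot from 4 to 8 bytes keeps the key array that follows it on an 8-byte boundary, which the word-aligned array format relies on.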
|
|
|
@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
|
||||
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
|
||||
assert(array.numElements == values.length)
|
||||
assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
|
||||
assert(array.getSizeInBytes ==
|
||||
8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length))
|
||||
values.zipWithIndex.foreach {
|
||||
case (value, index) => assert(array.getInt(index) == value)
|
||||
}
|
||||
|
@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
testArrayInt(map.keyArray, keys)
|
||||
testArrayInt(map.valueArray, values)
|
||||
|
||||
assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
|
||||
assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
|
||||
}
|
||||
|
||||
test("basic conversion with array type") {
|
||||
|
@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
val nestedArray = unsafeArray2.getArray(0)
|
||||
testArrayInt(nestedArray, Seq(3, 4))
|
||||
|
||||
assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
|
||||
assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes)
|
||||
|
||||
val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
|
||||
val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
|
||||
|
@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
val nestedMap = valueArray.getMap(0)
|
||||
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
|
||||
|
||||
assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
|
||||
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes))
|
||||
}
|
||||
|
||||
assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
|
||||
val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
|
||||
val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
|
||||
|
@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
assert(innerStruct.getLong(0) == 2L)
|
||||
}
|
||||
|
||||
assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
|
||||
assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
|
||||
|
||||
assert(unsafeRow.getSizeInBytes ==
|
||||
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
|
||||
|
@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
assert(innerStruct.getSizeInBytes == 8 + 8)
|
||||
assert(innerStruct.getLong(0) == 4L)
|
||||
|
||||
assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
|
||||
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
|
||||
}
|
||||
|
||||
assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
|
||||
assert(unsafeRow.getSizeInBytes ==
|
||||
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
|
||||
|
@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
val innerMap = field1.getMap(0)
|
||||
testMapInt(innerMap, Seq(1), Seq(2))
|
||||
|
||||
assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
|
||||
assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes))
|
||||
|
||||
val field2 = unsafeRow.getMap(1)
|
||||
assert(field2.numElements == 1)
|
||||
|
@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
|
|||
val innerArray = valueArray.getArray(0)
|
||||
testArrayInt(innerArray, Seq(4))
|
||||
|
||||
assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
|
||||
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes)
|
||||
}
|
||||
|
||||
assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
|
||||
|
||||
assert(unsafeRow.getSizeInBytes ==
|
||||
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
|
||||
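The size assertions in this suite follow from the new layout: an 8-byte element count, one null bit per element rounded up to whole 8-byte words, and a fixed-size region rounded up to an 8-byte boundary. The sketch below reproduces that arithmetic next to the old formula taken from the removed assertions (`4 + (4 + elementSize) * n`). `ArraySizeSketch` and its method names are hypothetical and only mirror the expressions used in the tests.

```java
final class ArraySizeSketch {
  // Round a byte count up to the next multiple of 8.
  static int roundUpToWord(int numBytes) {
    return ((numBytes + 7) / 8) * 8;
  }

  // 8 bytes for numElements plus one 8-byte word per 64 null bits.
  static int headerInBytes(int numElements) {
    return 8 + ((numElements + 63) / 64) * 8;
  }

  // New format: header plus word-aligned fixed-size region.
  static int newSizeInBytes(int numElements, int elementSize) {
    return headerInBytes(numElements) + roundUpToWord(elementSize * numElements);
  }

  // Old format: 4-byte count, one 4-byte offset per element, then the values.
  static int oldSizeInBytes(int numElements, int elementSize) {
    return 4 + 4 * numElements + elementSize * numElements;
  }

  public static void main(String[] args) {
    System.out.println(newSizeInBytes(3, 4) + " vs " + oldSizeInBytes(3, 4));
    int n = 1024 * 1024;
    System.out.println(newSizeInBytes(n, 4) + " vs " + oldSizeInBytes(n, 4));
  }
}
```

For a three-element int array the new format is slightly larger (32 bytes against 28) because of alignment, while for 1024 x 1024 ints it comes to 4,325,384 bytes against 8,388,612, since the per-element offset is replaced by a single bit. An array holding one nested value uses `newSizeInBytes(1, 8)`, which is the `8 + 8 + 8` term in the assertions above.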
|
|
|
@ -18,27 +18,190 @@
|
|||
package org.apache.spark.sql.catalyst.util
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
class UnsafeArraySuite extends SparkFunSuite {
|
||||
|
||||
test("from primitive int array") {
|
||||
val array = Array(1, 10, 100)
|
||||
val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
|
||||
assert(unsafe.numElements == 3)
|
||||
assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
|
||||
assert(unsafe.getInt(0) == 1)
|
||||
assert(unsafe.getInt(1) == 10)
|
||||
assert(unsafe.getInt(2) == 100)
|
||||
val booleanArray = Array(false, true)
|
||||
val shortArray = Array(1.toShort, 10.toShort, 100.toShort)
|
||||
val intArray = Array(1, 10, 100)
|
||||
val longArray = Array(1.toLong, 10.toLong, 100.toLong)
|
||||
val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat)
|
||||
val doubleArray = Array(1.1, 2.2, 3.3)
|
||||
val stringArray = Array("1", "10", "100")
|
||||
val dateArray = Array(
|
||||
DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1")).get,
|
||||
DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26")).get)
|
||||
val timestampArray = Array(
|
||||
DateTimeUtils.stringToTimestamp(UTF8String.fromString("1970-1-1 00:00:00")).get,
|
||||
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2016-7-26 00:00:00")).get)
|
||||
val decimalArray4_1 = Array(
|
||||
BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR),
|
||||
BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR))
|
||||
val decimalArray20_20 = Array(
|
||||
BigDecimal("1.2345678901234567890123456").setScale(21, BigDecimal.RoundingMode.FLOOR),
|
||||
BigDecimal("2.3456789012345678901234567").setScale(21, BigDecimal.RoundingMode.FLOOR))
|
||||
|
||||
val calenderintervalArray = Array(new CalendarInterval(3, 321), new CalendarInterval(1, 123))
|
||||
|
||||
val intMultiDimArray = Array(Array(1), Array(2, 20), Array(3, 30, 300))
|
||||
val doubleMultiDimArray = Array(
|
||||
Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3))
|
||||
|
||||
test("read array") {
|
||||
val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
|
||||
toRow(booleanArray).getArray(0)
|
||||
assert(unsafeBoolean.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeBoolean.numElements == booleanArray.length)
|
||||
booleanArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeBoolean.getBoolean(i) == e)
|
||||
}
|
||||
|
||||
val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind().
|
||||
toRow(shortArray).getArray(0)
|
||||
assert(unsafeShort.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeShort.numElements == shortArray.length)
|
||||
shortArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeShort.getShort(i) == e)
|
||||
}
|
||||
|
||||
val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind().
|
||||
toRow(intArray).getArray(0)
|
||||
assert(unsafeInt.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeInt.numElements == intArray.length)
|
||||
intArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeInt.getInt(i) == e)
|
||||
}
|
||||
|
||||
val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind().
|
||||
toRow(longArray).getArray(0)
|
||||
assert(unsafeLong.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeLong.numElements == longArray.length)
|
||||
longArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeLong.getLong(i) == e)
|
||||
}
|
||||
|
||||
val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind().
|
||||
toRow(floatArray).getArray(0)
|
||||
assert(unsafeFloat.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeFloat.numElements == floatArray.length)
|
||||
floatArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeFloat.getFloat(i) == e)
|
||||
}
|
||||
|
||||
val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind().
|
||||
toRow(doubleArray).getArray(0)
|
||||
assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeDouble.numElements == doubleArray.length)
|
||||
doubleArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeDouble.getDouble(i) == e)
|
||||
}
|
||||
|
||||
val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind().
|
||||
toRow(stringArray).getArray(0)
|
||||
assert(unsafeString.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeString.numElements == stringArray.length)
|
||||
stringArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeString.getUTF8String(i).toString().equals(e))
|
||||
}
|
||||
|
||||
val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind().
|
||||
toRow(dateArray).getArray(0)
|
||||
assert(unsafeDate.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeDate.numElements == dateArray.length)
|
||||
dateArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeDate.get(i, DateType) == e)
|
||||
}
|
||||
|
||||
val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind().
|
||||
toRow(timestampArray).getArray(0)
|
||||
assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeTimestamp.numElements == timestampArray.length)
|
||||
timestampArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeTimestamp.get(i, TimestampType) == e)
|
||||
}
|
||||
|
||||
Seq(decimalArray4_1, decimalArray20_20).map { decimalArray =>
|
||||
val decimal = decimalArray(0)
|
||||
val schema = new StructType().add(
|
||||
"array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
|
||||
val encoder = RowEncoder(schema).resolveAndBind()
|
||||
val externalRow = Row(decimalArray)
|
||||
val ir = encoder.toRow(externalRow)
|
||||
|
||||
val unsafeDecimal = ir.getArray(0)
|
||||
assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeDecimal.numElements == decimalArray.length)
|
||||
decimalArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e)
|
||||
}
|
||||
}
|
||||
|
||||
val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
|
||||
val encoder = RowEncoder(schema).resolveAndBind()
|
||||
val externalRow = Row(calenderintervalArray)
|
||||
val ir = encoder.toRow(externalRow)
|
||||
val unsafeCalendar = ir.getArray(0)
|
||||
assert(unsafeCalendar.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeCalendar.numElements == calenderintervalArray.length)
|
||||
calenderintervalArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeCalendar.getInterval(i) == e)
|
||||
}
|
||||
|
||||
val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind().
|
||||
toRow(intMultiDimArray).getArray(0)
|
||||
assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeMultiDimInt.numElements == intMultiDimArray.length)
|
||||
intMultiDimArray.zipWithIndex.map { case (a, j) =>
|
||||
val u = unsafeMultiDimInt.getArray(j)
|
||||
assert(u.isInstanceOf[UnsafeArrayData])
|
||||
assert(u.numElements == a.length)
|
||||
a.zipWithIndex.map { case (e, i) =>
|
||||
assert(u.getInt(i) == e)
|
||||
}
|
||||
}
|
||||
|
||||
val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind().
|
||||
toRow(doubleMultiDimArray).getArray(0)
|
||||
assert(unsafeMultiDimDouble.isInstanceOf[UnsafeArrayData])
|
||||
assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length)
|
||||
doubleMultiDimArray.zipWithIndex.map { case (a, j) =>
|
||||
val u = unsafeMultiDimDouble.getArray(j)
|
||||
assert(u.isInstanceOf[UnsafeArrayData])
|
||||
assert(u.numElements == a.length)
|
||||
a.zipWithIndex.map { case (e, i) =>
|
||||
assert(u.getDouble(i) == e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("from primitive double array") {
|
||||
val array = Array(1.1, 2.2, 3.3)
|
||||
val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
|
||||
assert(unsafe.numElements == 3)
|
||||
assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
|
||||
assert(unsafe.getDouble(0) == 1.1)
|
||||
assert(unsafe.getDouble(1) == 2.2)
|
||||
assert(unsafe.getDouble(2) == 3.3)
|
||||
test("from primitive array") {
|
||||
val unsafeInt = UnsafeArrayData.fromPrimitiveArray(intArray)
|
||||
assert(unsafeInt.numElements == 3)
|
||||
assert(unsafeInt.getSizeInBytes ==
|
||||
((8 + scala.math.ceil(3/64.toDouble) * 8 + 4 * 3 + 7).toInt / 8) * 8)
|
||||
intArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeInt.getInt(i) == e)
|
||||
}
|
||||
|
||||
val unsafeDouble = UnsafeArrayData.fromPrimitiveArray(doubleArray)
|
||||
assert(unsafeDouble.numElements == 3)
|
||||
assert(unsafeDouble.getSizeInBytes ==
|
||||
((8 + scala.math.ceil(3/64.toDouble) * 8 + 8 * 3 + 7).toInt / 8) * 8)
|
||||
doubleArray.zipWithIndex.map { case (e, i) =>
|
||||
assert(unsafeDouble.getDouble(i) == e)
|
||||
}
|
||||
}
|
||||
|
||||
test("to primitive array") {
|
||||
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
|
||||
assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray))
|
||||
|
||||
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
|
||||
assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType)
|
|||
|
||||
override def actualSize(row: InternalRow, ordinal: Int): Int = {
|
||||
val unsafeArray = getField(row, ordinal)
|
||||
4 + unsafeArray.getSizeInBytes
|
||||
8 + unsafeArray.getSizeInBytes
|
||||
}
|
||||
|
||||
override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
|
||||
|
@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType)
|
|||
|
||||
override def actualSize(row: InternalRow, ordinal: Int): Int = {
|
||||
val unsafeMap = getField(row, ordinal)
|
||||
4 + unsafeMap.getSizeInBytes
|
||||
8 + unsafeMap.getSizeInBytes
|
||||
}
|
||||
|
||||
override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
|
||||
|
|
|
@ -0,0 +1,232 @@
|
|||
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.benchmark

import scala.util.Random

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter}
import org.apache.spark.util.Benchmark

/**
 * Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData
 * To run this:
 * 1. replace ignore(...) with test(...)
 * 2. build/sbt "sql/test-only *benchmark.UnsafeArrayDataBenchmark"
 *
 * Benchmarks in this file are skipped in normal builds.
 */
class UnsafeArrayDataBenchmark extends BenchmarkBase {
|
||||
|
||||
def calculateHeaderPortionInBytes(count: Int) : Int = {
|
||||
/* 4 + 4 * count // Use this expression for SPARK-15962 */
|
||||
UnsafeArrayData.calculateHeaderPortionInBytes(count)
|
||||
}
|
||||
|
||||
def readUnsafeArray(iters: Int): Unit = {
|
||||
val count = 1024 * 1024 * 16
|
||||
val rand = new Random(42)
|
||||
|
||||
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
|
||||
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
|
||||
val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
|
||||
val readIntArray = { i: Int =>
|
||||
var n = 0
|
||||
while (n < iters) {
|
||||
val len = intUnsafeArray.numElements
|
||||
var sum = 0
|
||||
var i = 0
|
||||
while (i < len) {
|
||||
sum += intUnsafeArray.getInt(i)
|
||||
i += 1
|
||||
}
|
||||
n += 1
|
||||
}
|
||||
}
|
||||
|
||||
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
|
||||
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
|
||||
val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
|
||||
val readDoubleArray = { i: Int =>
|
||||
var n = 0
|
||||
while (n < iters) {
|
||||
val len = doubleUnsafeArray.numElements
|
||||
var sum = 0.0
|
||||
var i = 0
|
||||
while (i < len) {
|
||||
sum += doubleUnsafeArray.getDouble(i)
|
||||
i += 1
|
||||
}
|
||||
n += 1
|
||||
}
|
||||
}
|
||||
|
||||
val benchmark = new Benchmark("Read UnsafeArrayData", count * iters)
|
||||
benchmark.addCase("Int")(readIntArray)
|
||||
benchmark.addCase("Double")(readDoubleArray)
|
||||
benchmark.run
|
||||
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Read UnsafeArrayData:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            252 /  260        666.1           1.5       1.0X
Double                                         281 /  292        597.7           1.7       0.9X
*/
|
||||
}
|

  def writeUnsafeArray(iters: Int): Unit = {
    val count = 1024 * 1024 * 2
    val rand = new Random(42)

    var intTotalLength: Int = 0
    val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
    val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
    val writeIntArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements()
        n += 1
      }
      intTotalLength = len
    }

    var doubleTotalLength: Int = 0
    val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
    val writeDoubleArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements()
        n += 1
      }
      doubleTotalLength = len
    }

    val benchmark = new Benchmark("Write UnsafeArrayData", count * iters)
    benchmark.addCase("Int")(writeIntArray)
    benchmark.addCase("Double")(writeDoubleArray)
    benchmark.run
    /*
    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
    Intel Xeon E3-12xx v2 (Ivy Bridge)
    Write UnsafeArrayData:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    Int                                            196 /  249        107.0           9.3       1.0X
    Double                                         227 /  367         92.3          10.8       0.9X
    */
  }

  def getPrimitiveArray(iters: Int): Unit = {
    val count = 1024 * 1024 * 12
    val rand = new Random(42)

    var intTotalLength: Int = 0
    val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
    val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
    val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
    val readIntArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += intUnsafeArray.toIntArray.length
        n += 1
      }
      intTotalLength = len
    }

    var doubleTotalLength: Int = 0
    val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
    val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
    val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
    val readDoubleArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += doubleUnsafeArray.toDoubleArray.length
        n += 1
      }
      doubleTotalLength = len
    }

    val benchmark = new Benchmark("Get primitive array from UnsafeArrayData", count * iters)
    benchmark.addCase("Int")(readIntArray)
    benchmark.addCase("Double")(readDoubleArray)
    benchmark.run
    /*
    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
    Intel Xeon E3-12xx v2 (Ivy Bridge)
    Get primitive array from UnsafeArrayData: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    Int                                            151 /  198        415.8           2.4       1.0X
    Double                                         214 /  394        293.6           3.4       0.7X
    */
  }

  def putPrimitiveArray(iters: Int): Unit = {
    val count = 1024 * 1024 * 12
    val rand = new Random(42)

    var intTotalLen: Int = 0
    val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
    val createIntArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += UnsafeArrayData.fromPrimitiveArray(intPrimitiveArray).numElements
        n += 1
      }
      intTotalLen = len
    }

    var doubleTotalLen: Int = 0
    val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
    val createDoubleArray = { i: Int =>
      var len = 0
      var n = 0
      while (n < iters) {
        len += UnsafeArrayData.fromPrimitiveArray(doublePrimitiveArray).numElements
        n += 1
      }
      doubleTotalLen = len
    }

    val benchmark = new Benchmark("Create UnsafeArrayData from primitive array", count * iters)
    benchmark.addCase("Int")(createIntArray)
    benchmark.addCase("Double")(createDoubleArray)
    benchmark.run
    /*
    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
    Intel Xeon E3-12xx v2 (Ivy Bridge)
    Create UnsafeArrayData from primitive array: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
    Int                                            206 /  211        306.0           3.3       1.0X
    Double                                         232 /  406        271.6           3.7       0.9X
    */
  }

  ignore("Benchmark UnsafeArrayData") {
    readUnsafeArray(10)
    writeUnsafeArray(10)
    getPrimitiveArray(5)
    putPrimitiveArray(5)
  }
}
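The `Rate(M/s)` and `Per Row(ns)` columns in the result comments of the benchmark file appear to follow directly from the row count passed to `Benchmark` (`count * iters`) and the best time. The sketch below reproduces that arithmetic for the `Int` case of `writeUnsafeArray`. The object and method names (`BenchmarkArithmetic`, `rateMPerSec`, `perRowNs`) are hypothetical helpers written for this illustration, not part of Spark's `Benchmark` class.

```scala
// Hypothetical helper for illustration only; not part of Spark.
object BenchmarkArithmetic {
  // Millions of rows processed per second for a run that took `bestMs` milliseconds.
  def rateMPerSec(rows: Long, bestMs: Double): Double =
    rows / (bestMs / 1000.0) / 1e6

  // Time spent per row, in nanoseconds, for the same run.
  def perRowNs(rows: Long, bestMs: Double): Double =
    bestMs * 1e6 / rows

  def main(args: Array[String]): Unit = {
    // Inputs taken from writeUnsafeArray: count = 1024 * 1024 * 2, iters = 10,
    // and the reported best time of 196 ms for the Int case.
    val rows = 1024L * 1024 * 2 * 10
    println(s"rate    = ${rateMPerSec(rows, 196.0)} M/s")
    println(s"per row = ${perRowNs(rows, 196.0)} ns")
  }
}
```

With these inputs the two values come out at roughly 107 M/s and 9.3 ns, which agrees with the reported `Int` row. Small deviations in the other tables are expected because the best times shown there are rounded to whole milliseconds.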
@@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
-    checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
-    checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
+    checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8)
+    checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8))
     checkActualSize(STRUCT_TYPE, Row("hello"), 28)
   }
 