[SPARK-15962][SQL] Introduce implementation with a dense format for UnsafeArrayData

## What changes were proposed in this pull request?

This PR introduces more compact representation for ```UnsafeArrayData```.

```UnsafeArrayData``` needs to accept ```null``` value in each entry of an array. In the current version, it has three parts
```
[numElements] [offsets] [values]
```
`Offsets` has the number of `numElements`, and represents `null` if its value is negative. It may increase memory footprint, and introduces an indirection for accessing each of `values`.

This PR uses bitvectors to represent nullability for each element like `UnsafeRow`, and eliminates an indirection for accessing each element. The new ```UnsafeArrayData``` has four parts.
```
[numElements][null bits][values or offset&length][variable length portion]
```
In the `null bits` region, we store 1 bit per element, represents whether an element is null. Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries.
In the `values or offset&length` region, we store the content of elements. For fields that hold fixed-length primitive types, such as long, double, or int, we store the value directly in the field. For fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the base address of the array) that points to the beginning of the variable-length field and length (they are combined into a long). Each is word-aligned. For `variable length portion`, each is aligned to 8-byte boundaries.

The new format can reduce memory footprint and improve performance of accessing each element. An example of memory foot comparison:
1024x1024 elements integer array
Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024 + 1024x1024 = 2M bytes
Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024/8 + 1024x1024 = 1.25M bytes

In summary, we got 1.0-2.6x performance improvements over the code before applying this PR.
Here are performance results of [benchmark programs](04d2e4b6db/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala):

**Read UnsafeArrayData**: 1.7x and 1.6x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without SPARK-15962
Read UnsafeArrayData:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            430 /  436        390.0           2.6       1.0X
Double                                         456 /  485        367.8           2.7       0.9X

With SPARK-15962
Read UnsafeArrayData:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            252 /  260        666.1           1.5       1.0X
Double                                         281 /  292        597.7           1.7       0.9X
````
**Write UnsafeArrayData**: 1.0x and 1.1x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without SPARK-15962
Write UnsafeArrayData:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            203 /  273        103.4           9.7       1.0X
Double                                         239 /  356         87.9          11.4       0.8X

With SPARK-15962
Write UnsafeArrayData:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            196 /  249        107.0           9.3       1.0X
Double                                         227 /  367         92.3          10.8       0.9X
````

**Get primitive array from UnsafeArrayData**: 2.6x and 1.6x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without SPARK-15962
Get primitive array from UnsafeArrayData: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            207 /  217        304.2           3.3       1.0X
Double                                         257 /  363        245.2           4.1       0.8X

With SPARK-15962
Get primitive array from UnsafeArrayData: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            151 /  198        415.8           2.4       1.0X
Double                                         214 /  394        293.6           3.4       0.7X
````

**Create UnsafeArrayData from primitive array**: 1.7x and 2.1x performance improvements over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without SPARK-15962
Create UnsafeArrayData from primitive array: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            340 /  385        185.1           5.4       1.0X
Double                                         479 /  705        131.3           7.6       0.7X

With SPARK-15962
Create UnsafeArrayData from primitive array: Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            206 /  211        306.0           3.3       1.0X
Double                                         232 /  406        271.6           3.7       0.9X
````

1.7x and 1.4x performance improvements in [```UDTSerializationBenchmark```](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala)  over the code before applying this PR
````
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without SPARK-15962
VectorUDT de/serialization:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
serialize                                      442 /  533          0.0      441927.1       1.0X
deserialize                                    217 /  274          0.0      217087.6       2.0X

With SPARK-15962
VectorUDT de/serialization:              Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
serialize                                      265 /  318          0.0      265138.5       1.0X
deserialize                                    155 /  197          0.0      154611.4       1.7X
````

## How was this patch tested?

Added unit tests into ```UnsafeArraySuite```

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #13680 from kiszk/SPARK-15962.
This commit is contained in:
Kazuaki Ishizaki 2016-09-27 14:18:32 +08:00 committed by Wenchen Fan
parent 6ee28423ad
commit 85b0a15754
11 changed files with 755 additions and 236 deletions

View file

@ -29,6 +29,8 @@ public final class Platform {
private static final Unsafe _UNSAFE;
public static final int BOOLEAN_ARRAY_OFFSET;
public static final int BYTE_ARRAY_OFFSET;
public static final int SHORT_ARRAY_OFFSET;
@ -235,6 +237,7 @@ public final class Platform {
_UNSAFE = unsafe;
if (_UNSAFE != null) {
BOOLEAN_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(boolean[].class);
BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class);
INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
@ -242,6 +245,7 @@ public final class Platform {
FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class);
DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
} else {
BOOLEAN_ARRAY_OFFSET = 0;
BYTE_ARRAY_OFFSET = 0;
SHORT_ARRAY_OFFSET = 0;
INT_ARRAY_OFFSET = 0;

View file

@ -57,13 +57,12 @@ object UDTSerializationBenchmark {
}
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
serialize 380 / 392 0.0 379730.0 1.0X
deserialize 138 / 142 0.0 137816.6 2.8X
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
serialize 265 / 318 0.0 265138.5 1.0X
deserialize 155 / 197 0.0 154611.4 1.7X
*/
benchmark.run()
}

View file

@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@ -32,23 +33,31 @@ import org.apache.spark.unsafe.types.UTF8String;
/**
* An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
*
* Each tuple has three parts: [numElements] [offsets] [values]
* Each array has four parts:
* [numElements][null bits][values or offset&length][variable length portion]
*
* The `numElements` is 4 bytes storing the number of elements of this array.
* The `numElements` is 8 bytes storing the number of elements of this array.
*
* In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the
* base address of the array) of this element in `values` region. We can get the length of this
* element by subtracting next offset.
* Note that offset can by negative which means this element is null.
* In the `null bits` region, we store 1 bit per element, represents whether an element is null
* Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries.
*
* In the `values` region, we store the content of elements. As we can get length info, so elements
* can be variable-length.
* In the `values or offset&length` region, we store the content of elements. For fields that hold
* fixed-length primitive types, such as long, double, or int, we store the value directly
* in the field. The whole fixed-length portion (even for byte) is aligned to 8-byte boundaries.
* For fields with non-primitive or variable-length values, we store a relative offset
* (w.r.t. the base address of the array) that points to the beginning of the variable-length field
* and length (they are combined into a long). For variable length portion, each is aligned
* to 8-byte boundaries.
*
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
*/
// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
public final class UnsafeArrayData extends ArrayData {
public static int calculateHeaderPortionInBytes(int numFields) {
return 8 + ((numFields + 63)/ 64) * 8;
}
private Object baseObject;
private long baseOffset;
@ -56,25 +65,20 @@ public final class UnsafeArrayData extends ArrayData {
private int numElements;
// The size of this array's backing data, in bytes.
// The 4-bytes header of `numElements` is also included.
// The 8-bytes header of `numElements` is also included.
private int sizeInBytes;
/** The position to start storing array elements, */
private long elementOffset;
private long getElementOffset(int ordinal, int elementSize) {
return elementOffset + ordinal * elementSize;
}
public Object getBaseObject() { return baseObject; }
public long getBaseOffset() { return baseOffset; }
public int getSizeInBytes() { return sizeInBytes; }
private int getElementOffset(int ordinal) {
return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
}
private int getElementSize(int offset, int ordinal) {
if (ordinal == numElements - 1) {
return sizeInBytes - offset;
} else {
return Math.abs(getElementOffset(ordinal + 1)) - offset;
}
}
private void assertIndexIsValid(int ordinal) {
assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
@ -102,20 +106,22 @@ public final class UnsafeArrayData extends ArrayData {
* @param sizeInBytes the size of this array's backing data, in bytes
*/
public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
// Read the number of elements from the first 4 bytes.
final int numElements = Platform.getInt(baseObject, baseOffset);
// Read the number of elements from the first 8 bytes.
final long numElements = Platform.getLong(baseObject, baseOffset);
assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE";
this.numElements = numElements;
this.numElements = (int)numElements;
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.sizeInBytes = sizeInBytes;
this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
}
@Override
public boolean isNullAt(int ordinal) {
assertIndexIsValid(ordinal);
return getElementOffset(ordinal) < 0;
return BitSetMethods.isSet(baseObject, baseOffset + 8, ordinal);
}
@Override
@ -165,68 +171,50 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public boolean getBoolean(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return false;
return Platform.getBoolean(baseObject, baseOffset + offset);
return Platform.getBoolean(baseObject, getElementOffset(ordinal, 1));
}
@Override
public byte getByte(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getByte(baseObject, baseOffset + offset);
return Platform.getByte(baseObject, getElementOffset(ordinal, 1));
}
@Override
public short getShort(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getShort(baseObject, baseOffset + offset);
return Platform.getShort(baseObject, getElementOffset(ordinal, 2));
}
@Override
public int getInt(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getInt(baseObject, baseOffset + offset);
return Platform.getInt(baseObject, getElementOffset(ordinal, 4));
}
@Override
public long getLong(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getLong(baseObject, baseOffset + offset);
return Platform.getLong(baseObject, getElementOffset(ordinal, 8));
}
@Override
public float getFloat(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getFloat(baseObject, baseOffset + offset);
return Platform.getFloat(baseObject, getElementOffset(ordinal, 4));
}
@Override
public double getDouble(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return 0;
return Platform.getDouble(baseObject, baseOffset + offset);
return Platform.getDouble(baseObject, getElementOffset(ordinal, 8));
}
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
if (isNullAt(ordinal)) return null;
if (precision <= Decimal.MAX_LONG_DIGITS()) {
final long value = Platform.getLong(baseObject, baseOffset + offset);
return Decimal.apply(value, precision, scale);
return Decimal.apply(getLong(ordinal), precision, scale);
} else {
final byte[] bytes = getBinary(ordinal);
final BigInteger bigInteger = new BigInteger(bytes);
@ -237,19 +225,19 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UTF8String getUTF8String(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
}
@Override
public byte[] getBinary(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
final byte[] bytes = new byte[size];
Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size);
return bytes;
@ -257,9 +245,9 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public CalendarInterval getInterval(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int months = (int) Platform.getLong(baseObject, baseOffset + offset);
final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8);
return new CalendarInterval(months, microseconds);
@ -267,10 +255,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeRow getStruct(int ordinal, int numFields) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
final UnsafeRow row = new UnsafeRow(numFields);
row.pointTo(baseObject, baseOffset + offset, size);
return row;
@ -278,10 +266,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeArrayData getArray(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
final UnsafeArrayData array = new UnsafeArrayData();
array.pointTo(baseObject, baseOffset + offset, size);
return array;
@ -289,10 +277,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeMapData getMap(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
if (isNullAt(ordinal)) return null;
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) offsetAndSize;
final UnsafeMapData map = new UnsafeMapData();
map.pointTo(baseObject, baseOffset + offset, size);
return map;
@ -341,63 +329,108 @@ public final class UnsafeArrayData extends ArrayData {
return arrayCopy;
}
public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
@Override
public boolean[] toBooleanArray() {
boolean[] values = new boolean[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.BOOLEAN_ARRAY_OFFSET, numElements);
return values;
}
@Override
public byte[] toByteArray() {
byte[] values = new byte[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.BYTE_ARRAY_OFFSET, numElements);
return values;
}
@Override
public short[] toShortArray() {
short[] values = new short[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2);
return values;
}
@Override
public int[] toIntArray() {
int[] values = new int[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4);
return values;
}
@Override
public long[] toLongArray() {
long[] values = new long[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8);
return values;
}
@Override
public float[] toFloatArray() {
float[] values = new float[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4);
return values;
}
@Override
public double[] toDoubleArray() {
double[] values = new double[numElements];
Platform.copyMemory(
baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8);
return values;
}
private static UnsafeArrayData fromPrimitiveArray(
Object arr, int offset, int length, int elementSize) {
final long headerInBytes = calculateHeaderPortionInBytes(length);
final long valueRegionInBytes = elementSize * length;
final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
"it's too big.");
}
final int offsetRegionSize = 4 * arr.length;
final int valueRegionSize = 4 * arr.length;
final int totalSize = 4 + offsetRegionSize + valueRegionSize;
final byte[] data = new byte[totalSize];
final long[] data = new long[(int)totalSizeInLongs];
Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
int valueOffset = 4 + offsetRegionSize;
for (int i = 0; i < arr.length; i++) {
Platform.putInt(data, offsetPosition, valueOffset);
offsetPosition += 4;
valueOffset += 4;
}
Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
Platform.copyMemory(arr, offset, data,
Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
UnsafeArrayData result = new UnsafeArrayData();
result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
return result;
}
public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
}
public static UnsafeArrayData fromPrimitiveArray(byte[] arr) {
return fromPrimitiveArray(arr, Platform.BYTE_ARRAY_OFFSET, arr.length, 1);
}
public static UnsafeArrayData fromPrimitiveArray(short[] arr) {
return fromPrimitiveArray(arr, Platform.SHORT_ARRAY_OFFSET, arr.length, 2);
}
public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
return fromPrimitiveArray(arr, Platform.INT_ARRAY_OFFSET, arr.length, 4);
}
public static UnsafeArrayData fromPrimitiveArray(long[] arr) {
return fromPrimitiveArray(arr, Platform.LONG_ARRAY_OFFSET, arr.length, 8);
}
public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
return fromPrimitiveArray(arr, Platform.FLOAT_ARRAY_OFFSET, arr.length, 4);
}
public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
"it's too big.");
}
final int offsetRegionSize = 4 * arr.length;
final int valueRegionSize = 8 * arr.length;
final int totalSize = 4 + offsetRegionSize + valueRegionSize;
final byte[] data = new byte[totalSize];
Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
int valueOffset = 4 + offsetRegionSize;
for (int i = 0; i < arr.length; i++) {
Platform.putInt(data, offsetPosition, valueOffset);
offsetPosition += 4;
valueOffset += 8;
}
Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
UnsafeArrayData result = new UnsafeArrayData();
result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
return result;
return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
}
// TODO: add more specialized methods.
}

View file

@ -25,7 +25,7 @@ import org.apache.spark.unsafe.Platform;
/**
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
*
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head
* to indicate the number of bytes of the unsafe key array.
* [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*/
@ -65,14 +65,15 @@ public final class UnsafeMapData extends MapData {
* @param sizeInBytes the size of this map's backing data, in bytes
*/
public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
// Read the numBytes of key array from the first 4 bytes.
final int keyArraySize = Platform.getInt(baseObject, baseOffset);
final int valueArraySize = sizeInBytes - keyArraySize - 4;
// Read the numBytes of key array from the first 8 bytes.
final long keyArraySize = Platform.getLong(baseObject, baseOffset);
assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE";
final int valueArraySize = sizeInBytes - (int)keyArraySize - 8;
assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
keys.pointTo(baseObject, baseOffset + 8, (int)keyArraySize);
values.pointTo(baseObject, baseOffset + 8 + keyArraySize, valueArraySize);
assert keys.numElements() == values.numElements();

View file

@ -19,9 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
/**
* A helper class to write data into global row buffer using `UnsafeArrayData` format,
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
@ -33,134 +37,213 @@ public class UnsafeArrayWriter {
// The offset of the global buffer where we start to write this array.
private int startingOffset;
public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
// We need 4 bytes to store numElements and 4 bytes each element to store offset.
final int fixedSize = 4 + 4 * numElements;
// The number of elements in this array
private int numElements;
private int headerInBytes;
private void assertIndexIsValid(int index) {
assert index >= 0 : "index (" + index + ") should >= 0";
assert index < numElements : "index (" + index + ") should < " + numElements;
}
public void initialize(BufferHolder holder, int numElements, int elementSize) {
// We need 8 bytes to store numElements in header
this.numElements = numElements;
this.headerInBytes = calculateHeaderPortionInBytes(numElements);
this.holder = holder;
this.startingOffset = holder.cursor;
holder.grow(fixedSize);
Platform.putInt(holder.buffer, holder.cursor, numElements);
holder.cursor += fixedSize;
// Grows the global buffer ahead for header and fixed size data.
int fixedPartInBytes =
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
holder.grow(headerInBytes + fixedPartInBytes);
// Grows the global buffer ahead for fixed size data.
holder.grow(fixedElementSize * numElements);
// Write numElements and clear out null bits to header
Platform.putLong(holder.buffer, startingOffset, numElements);
for (int i = 8; i < headerInBytes; i += 8) {
Platform.putLong(holder.buffer, startingOffset + i, 0L);
}
// fill 0 into reminder part of 8-bytes alignment in unsafe array
for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
}
holder.cursor += (headerInBytes + fixedPartInBytes);
}
private long getElementOffset(int ordinal) {
return startingOffset + 4 + 4 * ordinal;
private void zeroOutPaddingBytes(int numBytes) {
if ((numBytes & 0x07) > 0) {
Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
}
}
public void setNullAt(int ordinal) {
final int relativeOffset = holder.cursor - startingOffset;
// Writes negative offset value to represent null element.
Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
private long getElementOffset(int ordinal, int elementSize) {
return startingOffset + headerInBytes + ordinal * elementSize;
}
public void setOffset(int ordinal) {
final int relativeOffset = holder.cursor - startingOffset;
Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
assertIndexIsValid(ordinal);
final long relativeOffset = currentCursor - startingOffset;
final long offsetAndSize = (relativeOffset << 32) | (long)size;
write(ordinal, offsetAndSize);
}
private void setNullBit(int ordinal) {
assertIndexIsValid(ordinal);
BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
}
public void setNullBoolean(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
}
public void setNullByte(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
}
public void setNullShort(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
}
public void setNullInt(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0);
}
public void setNullLong(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
}
public void setNullFloat(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
}
public void setNullDouble(int ordinal) {
setNullBit(ordinal);
// put zero into the corresponding field when set null
Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
}
public void setNull(int ordinal) { setNullLong(ordinal); }
public void write(int ordinal, boolean value) {
Platform.putBoolean(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 1;
assertIndexIsValid(ordinal);
Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
}
public void write(int ordinal, byte value) {
Platform.putByte(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 1;
assertIndexIsValid(ordinal);
Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
}
public void write(int ordinal, short value) {
Platform.putShort(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 2;
assertIndexIsValid(ordinal);
Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
}
public void write(int ordinal, int value) {
Platform.putInt(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 4;
assertIndexIsValid(ordinal);
Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
}
public void write(int ordinal, long value) {
Platform.putLong(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 8;
assertIndexIsValid(ordinal);
Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
}
public void write(int ordinal, float value) {
if (Float.isNaN(value)) {
value = Float.NaN;
}
Platform.putFloat(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 4;
assertIndexIsValid(ordinal);
Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
}
public void write(int ordinal, double value) {
if (Double.isNaN(value)) {
value = Double.NaN;
}
Platform.putDouble(holder.buffer, holder.cursor, value);
setOffset(ordinal);
holder.cursor += 8;
assertIndexIsValid(ordinal);
Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
assertIndexIsValid(ordinal);
if (input.changePrecision(precision, scale)) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
setOffset(ordinal);
holder.cursor += 8;
write(ordinal, input.toUnscaledLong());
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
assert bytes.length <= 16;
holder.grow(bytes.length);
final int numBytes = bytes.length;
assert numBytes <= 16;
int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
holder.grow(roundedSize);
zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
setOffset(ordinal);
holder.cursor += bytes.length;
bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
setOffsetAndSize(ordinal, holder.cursor, numBytes);
// move the cursor forward with 8-bytes boundary
holder.cursor += roundedSize;
}
} else {
setNullAt(ordinal);
setNull(ordinal);
}
}
public void write(int ordinal, UTF8String input) {
final int numBytes = input.numBytes();
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
// grow the global buffer before writing data.
holder.grow(numBytes);
holder.grow(roundedSize);
zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
input.writeToMemory(holder.buffer, holder.cursor);
setOffset(ordinal);
setOffsetAndSize(ordinal, holder.cursor, numBytes);
// move the cursor forward.
holder.cursor += numBytes;
holder.cursor += roundedSize;
}
public void write(int ordinal, byte[] input) {
final int numBytes = input.length;
final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
// grow the global buffer before writing data.
holder.grow(input.length);
holder.grow(roundedSize);
zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
setOffset(ordinal);
setOffsetAndSize(ordinal, holder.cursor, numBytes);
// move the cursor forward.
holder.cursor += input.length;
holder.cursor += roundedSize;
}
public void write(int ordinal, CalendarInterval input) {
@ -171,7 +254,7 @@ public class UnsafeArrayWriter {
Platform.putLong(holder.buffer, holder.cursor, input.months);
Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
setOffset(ordinal);
setOffsetAndSize(ordinal, holder.cursor, 16);
// move the cursor forward.
holder.cursor += 16;

View file

@ -124,7 +124,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
"""
case m @ MapType(kt, vt, _) =>
@ -134,7 +133,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
"""
case t: DecimalType =>
@ -189,29 +187,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val jt = ctx.javaType(et)
val fixedElementSize = et match {
val elementOrOffsetSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
case _ => 0
case _ => 8 // we need 8 bytes to store offset and length
}
val tmpCursor = ctx.freshName("tmpCursor")
val writeElement = et match {
case t: StructType =>
s"""
$arrayWriter.setOffset($index);
final int $tmpCursor = $bufferHolder.cursor;
${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case a @ ArrayType(et, _) =>
s"""
$arrayWriter.setOffset($index);
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
$arrayWriter.setOffset($index);
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
$arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case t: DecimalType =>
@ -222,16 +224,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
s"""
if ($input instanceof UnsafeArrayData) {
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
} else {
final int $numElements = $input.numElements();
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
$arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
for (int $index = 0; $index < $numElements; $index++) {
if ($input.isNullAt($index)) {
$arrayWriter.setNullAt($index);
$arrayWriter.setNull$primitiveTypeName($index);
} else {
final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
@ -261,16 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
// preserve 4 bytes to write the key array numBytes later.
$bufferHolder.grow(4);
$bufferHolder.cursor += 4;
// preserve 8 bytes to write the key array numBytes later.
$bufferHolder.grow(8);
$bufferHolder.cursor += 8;
// Remember the current cursor so that we can write numBytes of key array later.
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
// Write the numBytes of key array into the first 4 bytes.
Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
// Write the numBytes of key array into the first 8 bytes.
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
}

View file

@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
assert(array.numElements == values.length)
assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
assert(array.getSizeInBytes ==
8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length))
values.zipWithIndex.foreach {
case (value, index) => assert(array.getInt(index) == value)
}
@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
testArrayInt(map.keyArray, keys)
testArrayInt(map.valueArray, values)
assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
test("basic conversion with array type") {
@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedArray = unsafeArray2.getArray(0)
testArrayInt(nestedArray, Seq(3, 4))
assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes)
val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedMap = valueArray.getMap(0)
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes))
}
assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getLong(0) == 2L)
}
assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getSizeInBytes == 8 + 8)
assert(innerStruct.getLong(0) == 4L)
assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
}
assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes))
val field2 = unsafeRow.getMap(1)
assert(field2.numElements == 1)
@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerArray = valueArray.getArray(0)
testArrayInt(innerArray, Seq(4))
assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes)
}
assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))

View file

@ -18,27 +18,190 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class UnsafeArraySuite extends SparkFunSuite {
test("from primitive int array") {
val array = Array(1, 10, 100)
val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
assert(unsafe.numElements == 3)
assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
assert(unsafe.getInt(0) == 1)
assert(unsafe.getInt(1) == 10)
assert(unsafe.getInt(2) == 100)
val booleanArray = Array(false, true)
val shortArray = Array(1.toShort, 10.toShort, 100.toShort)
val intArray = Array(1, 10, 100)
val longArray = Array(1.toLong, 10.toLong, 100.toLong)
val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat)
val doubleArray = Array(1.1, 2.2, 3.3)
val stringArray = Array("1", "10", "100")
val dateArray = Array(
DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1")).get,
DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26")).get)
val timestampArray = Array(
DateTimeUtils.stringToTimestamp(UTF8String.fromString("1970-1-1 00:00:00")).get,
DateTimeUtils.stringToTimestamp(UTF8String.fromString("2016-7-26 00:00:00")).get)
val decimalArray4_1 = Array(
BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR),
BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR))
val decimalArray20_20 = Array(
BigDecimal("1.2345678901234567890123456").setScale(21, BigDecimal.RoundingMode.FLOOR),
BigDecimal("2.3456789012345678901234567").setScale(21, BigDecimal.RoundingMode.FLOOR))
val calenderintervalArray = Array(new CalendarInterval(3, 321), new CalendarInterval(1, 123))
val intMultiDimArray = Array(Array(1), Array(2, 20), Array(3, 30, 300))
val doubleMultiDimArray = Array(
Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3))
test("read array") {
val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
toRow(booleanArray).getArray(0)
assert(unsafeBoolean.isInstanceOf[UnsafeArrayData])
assert(unsafeBoolean.numElements == booleanArray.length)
booleanArray.zipWithIndex.map { case (e, i) =>
assert(unsafeBoolean.getBoolean(i) == e)
}
val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind().
toRow(shortArray).getArray(0)
assert(unsafeShort.isInstanceOf[UnsafeArrayData])
assert(unsafeShort.numElements == shortArray.length)
shortArray.zipWithIndex.map { case (e, i) =>
assert(unsafeShort.getShort(i) == e)
}
val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind().
toRow(intArray).getArray(0)
assert(unsafeInt.isInstanceOf[UnsafeArrayData])
assert(unsafeInt.numElements == intArray.length)
intArray.zipWithIndex.map { case (e, i) =>
assert(unsafeInt.getInt(i) == e)
}
val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind().
toRow(longArray).getArray(0)
assert(unsafeLong.isInstanceOf[UnsafeArrayData])
assert(unsafeLong.numElements == longArray.length)
longArray.zipWithIndex.map { case (e, i) =>
assert(unsafeLong.getLong(i) == e)
}
val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind().
toRow(floatArray).getArray(0)
assert(unsafeFloat.isInstanceOf[UnsafeArrayData])
assert(unsafeFloat.numElements == floatArray.length)
floatArray.zipWithIndex.map { case (e, i) =>
assert(unsafeFloat.getFloat(i) == e)
}
val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind().
toRow(doubleArray).getArray(0)
assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
assert(unsafeDouble.numElements == doubleArray.length)
doubleArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDouble.getDouble(i) == e)
}
val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind().
toRow(stringArray).getArray(0)
assert(unsafeString.isInstanceOf[UnsafeArrayData])
assert(unsafeString.numElements == stringArray.length)
stringArray.zipWithIndex.map { case (e, i) =>
assert(unsafeString.getUTF8String(i).toString().equals(e))
}
val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind().
toRow(dateArray).getArray(0)
assert(unsafeDate.isInstanceOf[UnsafeArrayData])
assert(unsafeDate.numElements == dateArray.length)
dateArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDate.get(i, DateType) == e)
}
val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind().
toRow(timestampArray).getArray(0)
assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData])
assert(unsafeTimestamp.numElements == timestampArray.length)
timestampArray.zipWithIndex.map { case (e, i) =>
assert(unsafeTimestamp.get(i, TimestampType) == e)
}
Seq(decimalArray4_1, decimalArray20_20).map { decimalArray =>
val decimal = decimalArray(0)
val schema = new StructType().add(
"array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
val encoder = RowEncoder(schema).resolveAndBind()
val externalRow = Row(decimalArray)
val ir = encoder.toRow(externalRow)
val unsafeDecimal = ir.getArray(0)
assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
assert(unsafeDecimal.numElements == decimalArray.length)
decimalArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e)
}
}
val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
val encoder = RowEncoder(schema).resolveAndBind()
val externalRow = Row(calenderintervalArray)
val ir = encoder.toRow(externalRow)
val unsafeCalendar = ir.getArray(0)
assert(unsafeCalendar.isInstanceOf[UnsafeArrayData])
assert(unsafeCalendar.numElements == calenderintervalArray.length)
calenderintervalArray.zipWithIndex.map { case (e, i) =>
assert(unsafeCalendar.getInterval(i) == e)
}
val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind().
toRow(intMultiDimArray).getArray(0)
assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData])
assert(unsafeMultiDimInt.numElements == intMultiDimArray.length)
intMultiDimArray.zipWithIndex.map { case (a, j) =>
val u = unsafeMultiDimInt.getArray(j)
assert(u.isInstanceOf[UnsafeArrayData])
assert(u.numElements == a.length)
a.zipWithIndex.map { case (e, i) =>
assert(u.getInt(i) == e)
}
}
val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind().
toRow(doubleMultiDimArray).getArray(0)
assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length)
doubleMultiDimArray.zipWithIndex.map { case (a, j) =>
val u = unsafeMultiDimDouble.getArray(j)
assert(u.isInstanceOf[UnsafeArrayData])
assert(u.numElements == a.length)
a.zipWithIndex.map { case (e, i) =>
assert(u.getDouble(i) == e)
}
}
}
test("from primitive double array") {
val array = Array(1.1, 2.2, 3.3)
val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
assert(unsafe.numElements == 3)
assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
assert(unsafe.getDouble(0) == 1.1)
assert(unsafe.getDouble(1) == 2.2)
assert(unsafe.getDouble(2) == 3.3)
test("from primitive array") {
val unsafeInt = UnsafeArrayData.fromPrimitiveArray(intArray)
assert(unsafeInt.numElements == 3)
assert(unsafeInt.getSizeInBytes ==
((8 + scala.math.ceil(3/64.toDouble) * 8 + 4 * 3 + 7).toInt / 8) * 8)
intArray.zipWithIndex.map { case (e, i) =>
assert(unsafeInt.getInt(i) == e)
}
val unsafeDouble = UnsafeArrayData.fromPrimitiveArray(doubleArray)
assert(unsafeDouble.numElements == 3)
assert(unsafeDouble.getSizeInBytes ==
((8 + scala.math.ceil(3/64.toDouble) * 8 + 8 * 3 + 7).toInt / 8) * 8)
doubleArray.zipWithIndex.map { case (e, i) =>
assert(unsafeDouble.getDouble(i) == e)
}
}
test("to primitive array") {
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray))
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
}
}

View file

@ -601,7 +601,7 @@ private[columnar] case class ARRAY(dataType: ArrayType)
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeArray = getField(row, ordinal)
4 + unsafeArray.getSizeInBytes
8 + unsafeArray.getSizeInBytes
}
override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
@ -640,7 +640,7 @@ private[columnar] case class MAP(dataType: MapType)
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeMap = getField(row, ordinal)
4 + unsafeMap.getSizeInBytes
8 + unsafeMap.getSizeInBytes
}
override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {

View file

@ -0,0 +1,232 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.benchmark
import scala.util.Random
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeArrayWriter}
import org.apache.spark.util.Benchmark
/**
* Benchmark [[UnsafeArrayDataBenchmark]] for UnsafeArrayData
* To run this:
* 1. replace ignore(...) with test(...)
* 2. build/sbt "sql/test-only *benchmark.UnsafeArrayDataBenchmark"
*
* Benchmarks in this file are skipped in normal builds.
*/
class UnsafeArrayDataBenchmark extends BenchmarkBase {
def calculateHeaderPortionInBytes(count: Int) : Int = {
/* 4 + 4 * count // Use this expression for SPARK-15962 */
UnsafeArrayData.calculateHeaderPortionInBytes(count)
}
def readUnsafeArray(iters: Int): Unit = {
val count = 1024 * 1024 * 16
val rand = new Random(42)
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
val readIntArray = { i: Int =>
var n = 0
while (n < iters) {
val len = intUnsafeArray.numElements
var sum = 0
var i = 0
while (i < len) {
sum += intUnsafeArray.getInt(i)
i += 1
}
n += 1
}
}
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
val readDoubleArray = { i: Int =>
var n = 0
while (n < iters) {
val len = doubleUnsafeArray.numElements
var sum = 0.0
var i = 0
while (i < len) {
sum += doubleUnsafeArray.getDouble(i)
i += 1
}
n += 1
}
}
val benchmark = new Benchmark("Read UnsafeArrayData", count * iters)
benchmark.addCase("Int")(readIntArray)
benchmark.addCase("Double")(readDoubleArray)
benchmark.run
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 252 / 260 666.1 1.5 1.0X
Double 281 / 292 597.7 1.7 0.9X
*/
}
def writeUnsafeArray(iters: Int): Unit = {
val count = 1024 * 1024 * 2
val rand = new Random(42)
var intTotalLength: Int = 0
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
val writeIntArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += intEncoder.toRow(intPrimitiveArray).getArray(0).numElements()
n += 1
}
intTotalLength = len
}
var doubleTotalLength: Int = 0
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
val writeDoubleArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += doubleEncoder.toRow(doublePrimitiveArray).getArray(0).numElements()
n += 1
}
doubleTotalLength = len
}
val benchmark = new Benchmark("Write UnsafeArrayData", count * iters)
benchmark.addCase("Int")(writeIntArray)
benchmark.addCase("Double")(writeDoubleArray)
benchmark.run
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 196 / 249 107.0 9.3 1.0X
Double 227 / 367 92.3 10.8 0.9X
*/
}
def getPrimitiveArray(iters: Int): Unit = {
val count = 1024 * 1024 * 12
val rand = new Random(42)
var intTotalLength: Int = 0
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
val intUnsafeArray = intEncoder.toRow(intPrimitiveArray).getArray(0)
val readIntArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += intUnsafeArray.toIntArray.length
n += 1
}
intTotalLength = len
}
var doubleTotalLength: Int = 0
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
val doubleUnsafeArray = doubleEncoder.toRow(doublePrimitiveArray).getArray(0)
val readDoubleArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += doubleUnsafeArray.toDoubleArray.length
n += 1
}
doubleTotalLength = len
}
val benchmark = new Benchmark("Get primitive array from UnsafeArrayData", count * iters)
benchmark.addCase("Int")(readIntArray)
benchmark.addCase("Double")(readDoubleArray)
benchmark.run
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 151 / 198 415.8 2.4 1.0X
Double 214 / 394 293.6 3.4 0.7X
*/
}
def putPrimitiveArray(iters: Int): Unit = {
val count = 1024 * 1024 * 12
val rand = new Random(42)
var intTotalLen: Int = 0
val intPrimitiveArray = Array.fill[Int](count) { rand.nextInt }
val createIntArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += UnsafeArrayData.fromPrimitiveArray(intPrimitiveArray).numElements
n += 1
}
intTotalLen = len
}
var doubleTotalLen: Int = 0
val doublePrimitiveArray = Array.fill[Double](count) { rand.nextDouble }
val createDoubleArray = { i: Int =>
var len = 0
var n = 0
while (n < iters) {
len += UnsafeArrayData.fromPrimitiveArray(doublePrimitiveArray).numElements
n += 1
}
doubleTotalLen = len
}
val benchmark = new Benchmark("Create UnsafeArrayData from primitive array", count * iters)
benchmark.addCase("Int")(createIntArray)
benchmark.addCase("Double")(createDoubleArray)
benchmark.run
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 206 / 211 306.0 3.3 1.0X
Double 232 / 406 271.6 3.7 0.9X
*/
}
ignore("Benchmark UnsafeArrayData") {
readUnsafeArray(10)
writeUnsafeArray(10)
getPrimitiveArray(5)
putPrimitiveArray(5)
}
}

View file

@ -73,8 +73,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
checkActualSize(ARRAY_TYPE, Array[Any](1), 8 + 8 + 8 + 8)
checkActualSize(MAP_TYPE, Map(1 -> "a"), 8 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8))
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
}