[SPARK-10917] [SQL] improve performance of complex type in columnar cache
This PR improve the performance of complex types in columnar cache by using UnsafeProjection instead of KryoSerializer. A simple benchmark show that this PR could improve the performance of scanning a cached table with complex columns by 15x (comparing to Spark 1.5). Here is the code used to benchmark: ``` df = sc.range(1<<23).map(lambda i: Row(a=Row(b=i, c=str(i)), d=range(10), e=dict(zip(range(10), [str(i) for i in range(10)])))).toDF() df.write.parquet("table") ``` ``` df = sqlContext.read.parquet("table") df.cache() df.count() t = time.time() print df.select("*")._jdf.queryExecution().toRdd().count() print time.time() - t ``` Author: Davies Liu <davies@databricks.com> Closes #8971 from davies/complex.
This commit is contained in:
parent
dd36ec6bc5
commit
075a0b6582
|
@ -23,7 +23,6 @@ import java.math.BigInteger;
|
|||
import org.apache.spark.sql.types.*;
|
||||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods;
|
||||
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
|
||||
import org.apache.spark.unsafe.types.CalendarInterval;
|
||||
import org.apache.spark.unsafe.types.UTF8String;
|
||||
|
||||
|
|
|
@ -48,6 +48,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
|
|||
}
|
||||
|
||||
object ArrayBasedMapData {
|
||||
def apply(map: Map[Any, Any]): ArrayBasedMapData = {
|
||||
val array = map.toArray
|
||||
ArrayBasedMapData(array.map(_._1), array.map(_._2))
|
||||
}
|
||||
|
||||
def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
|
||||
new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar
|
|||
|
||||
import java.nio.{ByteBuffer, ByteOrder}
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.MutableRow
|
||||
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -61,6 +62,10 @@ private[sql] abstract class BasicColumnAccessor[JvmType](
|
|||
protected def underlyingBuffer = buffer
|
||||
}
|
||||
|
||||
private[sql] class NullColumnAccess(buffer: ByteBuffer)
|
||||
extends BasicColumnAccessor[Any](buffer, NULL)
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
|
||||
override protected val buffer: ByteBuffer,
|
||||
override protected val columnType: NativeColumnType[T])
|
||||
|
@ -96,11 +101,23 @@ private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
|
|||
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
|
||||
extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
|
||||
private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
|
||||
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
|
||||
|
||||
private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType)
|
||||
extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType))
|
||||
private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
|
||||
extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType))
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
|
||||
extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
|
||||
extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
|
||||
extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
|
||||
with NullableColumnAccessor
|
||||
|
||||
private[sql] object ColumnAccessor {
|
||||
|
@ -108,6 +125,7 @@ private[sql] object ColumnAccessor {
|
|||
val buf = buffer.order(ByteOrder.nativeOrder)
|
||||
|
||||
dataType match {
|
||||
case NullType => new NullColumnAccess(buf)
|
||||
case BooleanType => new BooleanColumnAccessor(buf)
|
||||
case ByteType => new ByteColumnAccessor(buf)
|
||||
case ShortType => new ShortColumnAccessor(buf)
|
||||
|
@ -117,9 +135,15 @@ private[sql] object ColumnAccessor {
|
|||
case DoubleType => new DoubleColumnAccessor(buf)
|
||||
case StringType => new StringColumnAccessor(buf)
|
||||
case BinaryType => new BinaryColumnAccessor(buf)
|
||||
case DecimalType.Fixed(precision, scale) if precision < 19 =>
|
||||
new FixedDecimalColumnAccessor(buf, precision, scale)
|
||||
case other => new GenericColumnAccessor(buf, other)
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
|
||||
new CompactDecimalColumnAccessor(buf, dt)
|
||||
case dt: DecimalType => new DecimalColumnAccessor(buf, dt)
|
||||
case struct: StructType => new StructColumnAccessor(buf, struct)
|
||||
case array: ArrayType => new ArrayColumnAccessor(buf, array)
|
||||
case map: MapType => new MapColumnAccessor(buf, map)
|
||||
case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer)
|
||||
case other =>
|
||||
throw new Exception(s"not support type: $other")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,6 +77,10 @@ private[sql] class BasicColumnBuilder[JvmType](
|
|||
}
|
||||
}
|
||||
|
||||
private[sql] class NullColumnBuilder
|
||||
extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
|
||||
with NullableColumnBuilder
|
||||
|
||||
private[sql] abstract class ComplexColumnBuilder[JvmType](
|
||||
columnStats: ColumnStats,
|
||||
columnType: ColumnType[JvmType])
|
||||
|
@ -109,16 +113,20 @@ private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringCol
|
|||
|
||||
private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
|
||||
|
||||
private[sql] class FixedDecimalColumnBuilder(
|
||||
precision: Int,
|
||||
scale: Int)
|
||||
extends NativeColumnBuilder(
|
||||
new FixedDecimalColumnStats(precision, scale),
|
||||
FIXED_DECIMAL(precision, scale))
|
||||
private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType)
|
||||
extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
|
||||
|
||||
// TODO (lian) Add support for array, struct and map
|
||||
private[sql] class GenericColumnBuilder(dataType: DataType)
|
||||
extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType))
|
||||
private[sql] class DecimalColumnBuilder(dataType: DecimalType)
|
||||
extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
|
||||
|
||||
private[sql] class StructColumnBuilder(dataType: StructType)
|
||||
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
|
||||
|
||||
private[sql] class ArrayColumnBuilder(dataType: ArrayType)
|
||||
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
|
||||
|
||||
private[sql] class MapColumnBuilder(dataType: MapType)
|
||||
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
|
||||
|
||||
private[sql] object ColumnBuilder {
|
||||
val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
|
||||
|
@ -145,6 +153,7 @@ private[sql] object ColumnBuilder {
|
|||
columnName: String = "",
|
||||
useCompression: Boolean = false): ColumnBuilder = {
|
||||
val builder: ColumnBuilder = dataType match {
|
||||
case NullType => new NullColumnBuilder
|
||||
case BooleanType => new BooleanColumnBuilder
|
||||
case ByteType => new ByteColumnBuilder
|
||||
case ShortType => new ShortColumnBuilder
|
||||
|
@ -154,9 +163,16 @@ private[sql] object ColumnBuilder {
|
|||
case DoubleType => new DoubleColumnBuilder
|
||||
case StringType => new StringColumnBuilder
|
||||
case BinaryType => new BinaryColumnBuilder
|
||||
case DecimalType.Fixed(precision, scale) if precision < 19 =>
|
||||
new FixedDecimalColumnBuilder(precision, scale)
|
||||
case other => new GenericColumnBuilder(other)
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
|
||||
new CompactDecimalColumnBuilder(dt)
|
||||
case dt: DecimalType => new DecimalColumnBuilder(dt)
|
||||
case struct: StructType => new StructColumnBuilder(struct)
|
||||
case array: ArrayType => new ArrayColumnBuilder(array)
|
||||
case map: MapType => new MapColumnBuilder(map)
|
||||
case udt: UserDefinedType[_] =>
|
||||
return apply(udt.sqlType, initialSize, columnName, useCompression)
|
||||
case other =>
|
||||
throw new Exception(s"not suppported type: $other")
|
||||
}
|
||||
|
||||
builder.initialize(initialSize, columnName, useCompression)
|
||||
|
|
|
@ -235,7 +235,9 @@ private[sql] class BinaryColumnStats extends ColumnStats {
|
|||
new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
|
||||
}
|
||||
|
||||
private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
|
||||
private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
|
||||
def this(dt: DecimalType) = this(dt.precision, dt.scale)
|
||||
|
||||
protected var upper: Decimal = null
|
||||
protected var lower: Decimal = null
|
||||
|
||||
|
@ -245,7 +247,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
|
|||
val value = row.getDecimal(ordinal, precision, scale)
|
||||
if (upper == null || value.compareTo(upper) > 0) upper = value
|
||||
if (lower == null || value.compareTo(lower) < 0) lower = value
|
||||
sizeInBytes += FIXED_DECIMAL.defaultSize
|
||||
// TODO: this is not right for DecimalType with precision > 18
|
||||
sizeInBytes += 8
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,8 +256,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
|
|||
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
|
||||
}
|
||||
|
||||
private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
|
||||
val columnType = GENERIC(dataType)
|
||||
private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
|
||||
val columnType = ColumnType(dataType)
|
||||
|
||||
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
|
||||
super.gatherStats(row, ordinal)
|
||||
|
|
|
@ -17,14 +17,15 @@
|
|||
|
||||
package org.apache.spark.sql.columnar
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.math.{BigDecimal, BigInteger}
|
||||
import java.nio.{ByteOrder, ByteBuffer}
|
||||
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.MutableRow
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
/**
|
||||
|
@ -102,6 +103,16 @@ private[sql] sealed abstract class ColumnType[JvmType] {
|
|||
override def toString: String = getClass.getSimpleName.stripSuffix("$")
|
||||
}
|
||||
|
||||
private[sql] object NULL extends ColumnType[Any] {
|
||||
|
||||
override def dataType: DataType = NullType
|
||||
override def defaultSize: Int = 0
|
||||
override def append(v: Any, buffer: ByteBuffer): Unit = {}
|
||||
override def extract(buffer: ByteBuffer): Any = null
|
||||
override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal)
|
||||
override def getField(row: InternalRow, ordinal: Int): Any = null
|
||||
}
|
||||
|
||||
private[sql] abstract class NativeColumnType[T <: AtomicType](
|
||||
val dataType: T,
|
||||
val defaultSize: Int)
|
||||
|
@ -339,10 +350,8 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
|
|||
override def clone(v: UTF8String): UTF8String = v.clone()
|
||||
}
|
||||
|
||||
private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
|
||||
extends NativeColumnType(
|
||||
DecimalType(precision, scale),
|
||||
FIXED_DECIMAL.defaultSize) {
|
||||
private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
|
||||
extends NativeColumnType(DecimalType(precision, scale), 8) {
|
||||
|
||||
override def extract(buffer: ByteBuffer): Decimal = {
|
||||
Decimal(buffer.getLong(), precision, scale)
|
||||
|
@ -365,32 +374,39 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
|
|||
}
|
||||
}
|
||||
|
||||
private[sql] object FIXED_DECIMAL {
|
||||
val defaultSize = 8
|
||||
private[sql] object COMPACT_DECIMAL {
|
||||
def apply(dt: DecimalType): COMPACT_DECIMAL = {
|
||||
COMPACT_DECIMAL(dt.precision, dt.scale)
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] sealed abstract class ByteArrayColumnType(val defaultSize: Int)
|
||||
extends ColumnType[Array[Byte]] {
|
||||
private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
|
||||
extends ColumnType[JvmType] {
|
||||
|
||||
def serialize(value: JvmType): Array[Byte]
|
||||
def deserialize(bytes: Array[Byte]): JvmType
|
||||
|
||||
override def actualSize(row: InternalRow, ordinal: Int): Int = {
|
||||
getField(row, ordinal).length + 4
|
||||
// TODO: grow the buffer in append(), so serialize() will not be called twice
|
||||
serialize(getField(row, ordinal)).length + 4
|
||||
}
|
||||
|
||||
override def append(v: Array[Byte], buffer: ByteBuffer): Unit = {
|
||||
buffer.putInt(v.length).put(v, 0, v.length)
|
||||
override def append(v: JvmType, buffer: ByteBuffer): Unit = {
|
||||
val bytes = serialize(v)
|
||||
buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
|
||||
}
|
||||
|
||||
override def extract(buffer: ByteBuffer): Array[Byte] = {
|
||||
override def extract(buffer: ByteBuffer): JvmType = {
|
||||
val length = buffer.getInt()
|
||||
val bytes = new Array[Byte](length)
|
||||
buffer.get(bytes, 0, length)
|
||||
bytes
|
||||
deserialize(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] object BINARY extends ByteArrayColumnType(16) {
|
||||
private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
|
||||
|
||||
def dataType: DataType = BooleanType
|
||||
def dataType: DataType = BinaryType
|
||||
|
||||
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
|
||||
row.update(ordinal, value)
|
||||
|
@ -399,24 +415,164 @@ private[sql] object BINARY extends ByteArrayColumnType(16) {
|
|||
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
|
||||
row.getBinary(ordinal)
|
||||
}
|
||||
|
||||
def serialize(value: Array[Byte]): Array[Byte] = value
|
||||
def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
|
||||
}
|
||||
|
||||
// Used to process generic objects (all types other than those listed above). Objects should be
|
||||
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
|
||||
// byte array.
|
||||
private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(16) {
|
||||
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
|
||||
row.update(ordinal, SparkSqlSerializer.deserialize[Any](value))
|
||||
private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int)
|
||||
extends ByteArrayColumnType[Decimal](12) {
|
||||
|
||||
override val dataType: DataType = DecimalType(precision, scale)
|
||||
|
||||
override def getField(row: InternalRow, ordinal: Int): Decimal = {
|
||||
row.getDecimal(ordinal, precision, scale)
|
||||
}
|
||||
|
||||
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
|
||||
SparkSqlSerializer.serialize(row.get(ordinal, dataType))
|
||||
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
|
||||
row.setDecimal(ordinal, value, precision)
|
||||
}
|
||||
|
||||
override def serialize(value: Decimal): Array[Byte] = {
|
||||
value.toJavaBigDecimal.unscaledValue().toByteArray
|
||||
}
|
||||
|
||||
override def deserialize(bytes: Array[Byte]): Decimal = {
|
||||
val javaDecimal = new BigDecimal(new BigInteger(bytes), scale)
|
||||
Decimal.apply(javaDecimal, precision, scale)
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] object LARGE_DECIMAL {
|
||||
def apply(dt: DecimalType): LARGE_DECIMAL = {
|
||||
LARGE_DECIMAL(dt.precision, dt.scale)
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] case class STRUCT(dataType: StructType)
|
||||
extends ByteArrayColumnType[InternalRow](20) {
|
||||
|
||||
private val projection: UnsafeProjection =
|
||||
UnsafeProjection.create(dataType)
|
||||
private val numOfFields: Int = dataType.fields.size
|
||||
|
||||
override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = {
|
||||
row.update(ordinal, value)
|
||||
}
|
||||
|
||||
override def getField(row: InternalRow, ordinal: Int): InternalRow = {
|
||||
row.getStruct(ordinal, numOfFields)
|
||||
}
|
||||
|
||||
override def serialize(value: InternalRow): Array[Byte] = {
|
||||
val unsafeRow = if (value.isInstanceOf[UnsafeRow]) {
|
||||
value.asInstanceOf[UnsafeRow]
|
||||
} else {
|
||||
projection(value)
|
||||
}
|
||||
unsafeRow.getBytes
|
||||
}
|
||||
|
||||
override def deserialize(bytes: Array[Byte]): InternalRow = {
|
||||
val unsafeRow = new UnsafeRow
|
||||
unsafeRow.pointTo(bytes, numOfFields, bytes.length)
|
||||
unsafeRow
|
||||
}
|
||||
|
||||
override def clone(v: InternalRow): InternalRow = v.copy()
|
||||
}
|
||||
|
||||
private[sql] case class ARRAY(dataType: ArrayType)
|
||||
extends ByteArrayColumnType[ArrayData](16) {
|
||||
|
||||
private lazy val projection = UnsafeProjection.create(Array[DataType](dataType))
|
||||
private val mutableRow = new GenericMutableRow(new Array[Any](1))
|
||||
|
||||
override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = {
|
||||
row.update(ordinal, value)
|
||||
}
|
||||
|
||||
override def getField(row: InternalRow, ordinal: Int): ArrayData = {
|
||||
row.getArray(ordinal)
|
||||
}
|
||||
|
||||
override def serialize(value: ArrayData): Array[Byte] = {
|
||||
val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) {
|
||||
value.asInstanceOf[UnsafeArrayData]
|
||||
} else {
|
||||
mutableRow(0) = value
|
||||
projection(mutableRow).getArray(0)
|
||||
}
|
||||
val outputBuffer =
|
||||
ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder())
|
||||
outputBuffer.putInt(unsafeArray.numElements())
|
||||
val underlying = outputBuffer.array()
|
||||
unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4)
|
||||
underlying
|
||||
}
|
||||
|
||||
override def deserialize(bytes: Array[Byte]): ArrayData = {
|
||||
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
|
||||
val numElements = buffer.getInt
|
||||
val array = new UnsafeArrayData
|
||||
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4)
|
||||
array
|
||||
}
|
||||
|
||||
override def clone(v: ArrayData): ArrayData = v.copy()
|
||||
}
|
||||
|
||||
private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) {
|
||||
|
||||
private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType))
|
||||
private val mutableRow = new GenericMutableRow(new Array[Any](1))
|
||||
|
||||
override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = {
|
||||
row.update(ordinal, value)
|
||||
}
|
||||
|
||||
override def getField(row: InternalRow, ordinal: Int): MapData = {
|
||||
row.getMap(ordinal)
|
||||
}
|
||||
|
||||
override def serialize(value: MapData): Array[Byte] = {
|
||||
val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) {
|
||||
value.asInstanceOf[UnsafeMapData]
|
||||
} else {
|
||||
mutableRow(0) = value
|
||||
projection(mutableRow).getMap(0)
|
||||
}
|
||||
|
||||
val outputBuffer =
|
||||
ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder())
|
||||
outputBuffer.putInt(unsafeMap.numElements())
|
||||
val keyBytes = unsafeMap.keyArray().getSizeInBytes
|
||||
outputBuffer.putInt(keyBytes)
|
||||
val underlying = outputBuffer.array()
|
||||
unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8)
|
||||
unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes)
|
||||
underlying
|
||||
}
|
||||
|
||||
override def deserialize(bytes: Array[Byte]): MapData = {
|
||||
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
|
||||
val numElements = buffer.getInt
|
||||
val keyArraySize = buffer.getInt
|
||||
val keyArray = new UnsafeArrayData
|
||||
val valueArray = new UnsafeArrayData
|
||||
keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize)
|
||||
valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements,
|
||||
bytes.length - 8 - keyArraySize)
|
||||
new UnsafeMapData(keyArray, valueArray)
|
||||
}
|
||||
|
||||
override def clone(v: MapData): MapData = v.copy()
|
||||
}
|
||||
|
||||
private[sql] object ColumnType {
|
||||
def apply(dataType: DataType): ColumnType[_] = {
|
||||
dataType match {
|
||||
case NullType => NULL
|
||||
case BooleanType => BOOLEAN
|
||||
case ByteType => BYTE
|
||||
case ShortType => SHORT
|
||||
|
@ -426,9 +582,14 @@ private[sql] object ColumnType {
|
|||
case DoubleType => DOUBLE
|
||||
case StringType => STRING
|
||||
case BinaryType => BINARY
|
||||
case DecimalType.Fixed(precision, scale) if precision < 19 =>
|
||||
FIXED_DECIMAL(precision, scale)
|
||||
case other => GENERIC(other)
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
|
||||
case dt: DecimalType => LARGE_DECIMAL(dt)
|
||||
case arr: ArrayType => ARRAY(arr)
|
||||
case map: MapType => MAP(map)
|
||||
case struct: StructType => STRUCT(struct)
|
||||
case udt: UserDefinedType[_] => apply(udt.sqlType)
|
||||
case other =>
|
||||
throw new Exception(s"Unsupported type: $other")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.spark.sql.columnar
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -76,11 +75,11 @@ class ColumnStatsSuite extends SparkFunSuite {
|
|||
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
|
||||
initialStatistics: GenericInternalRow): Unit = {
|
||||
|
||||
val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
|
||||
val columnType = FIXED_DECIMAL(15, 10)
|
||||
val columnStatsName = classOf[DecimalColumnStats].getSimpleName
|
||||
val columnType = COMPACT_DECIMAL(15, 10)
|
||||
|
||||
test(s"$columnStatsName: empty") {
|
||||
val columnStats = new FixedDecimalColumnStats(15, 10)
|
||||
val columnStats = new DecimalColumnStats(15, 10)
|
||||
columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
|
||||
case (actual, expected) => assert(actual === expected)
|
||||
}
|
||||
|
@ -89,7 +88,7 @@ class ColumnStatsSuite extends SparkFunSuite {
|
|||
test(s"$columnStatsName: non-empty") {
|
||||
import org.apache.spark.sql.columnar.ColumnarTestUtils._
|
||||
|
||||
val columnStats = new FixedDecimalColumnStats(15, 10)
|
||||
val columnStats = new DecimalColumnStats(15, 10)
|
||||
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
|
||||
rows.foreach(columnStats.gatherStats(_, 0))
|
||||
|
||||
|
|
|
@ -19,28 +19,25 @@ package org.apache.spark.sql.columnar
|
|||
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
import com.esotericsoftware.kryo.io.{Input, Output}
|
||||
import com.esotericsoftware.kryo.{Kryo, Serializer}
|
||||
|
||||
import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
|
||||
import org.apache.spark.serializer.KryoRegistrator
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
|
||||
import org.apache.spark.sql.columnar.ColumnarTestUtils._
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.{Logging, SparkFunSuite}
|
||||
|
||||
|
||||
class ColumnTypeSuite extends SparkFunSuite with Logging {
|
||||
private val DEFAULT_BUFFER_SIZE = 512
|
||||
private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType))
|
||||
private val MAP_TYPE = MAP(MapType(IntegerType, StringType))
|
||||
private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType))
|
||||
private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil))
|
||||
|
||||
test("defaultSize") {
|
||||
val checks = Map(
|
||||
BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4,
|
||||
LONG -> 8, FLOAT -> 4, DOUBLE -> 8,
|
||||
STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8,
|
||||
MAP_GENERIC -> 16)
|
||||
NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
|
||||
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
|
||||
STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32)
|
||||
|
||||
checks.foreach { case (columnType, expectedSize) =>
|
||||
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
|
||||
|
@ -50,18 +47,19 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
}
|
||||
|
||||
test("actualSize") {
|
||||
def checkActualSize[JvmType](
|
||||
columnType: ColumnType[JvmType],
|
||||
value: JvmType,
|
||||
def checkActualSize(
|
||||
columnType: ColumnType[_],
|
||||
value: Any,
|
||||
expected: Int): Unit = {
|
||||
|
||||
assertResult(expected, s"Wrong actualSize for $columnType") {
|
||||
val row = new GenericMutableRow(1)
|
||||
columnType.setField(row, 0, value)
|
||||
row.update(0, CatalystTypeConverters.convertToCatalyst(value))
|
||||
columnType.actualSize(row, 0)
|
||||
}
|
||||
}
|
||||
|
||||
checkActualSize(NULL, null, 0)
|
||||
checkActualSize(BOOLEAN, true, 1)
|
||||
checkActualSize(BYTE, Byte.MaxValue, 1)
|
||||
checkActualSize(SHORT, Short.MaxValue, 2)
|
||||
|
@ -69,177 +67,66 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
checkActualSize(LONG, Long.MaxValue, 8)
|
||||
checkActualSize(FLOAT, Float.MaxValue, 4)
|
||||
checkActualSize(DOUBLE, Double.MaxValue, 8)
|
||||
checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
|
||||
checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
|
||||
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
|
||||
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
|
||||
|
||||
val generic = Map(1 -> "a")
|
||||
checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
|
||||
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
|
||||
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
|
||||
checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
|
||||
checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
|
||||
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
|
||||
}
|
||||
|
||||
testNativeColumnType(BOOLEAN)(
|
||||
(buffer: ByteBuffer, v: Boolean) => {
|
||||
buffer.put((if (v) 1 else 0).toByte)
|
||||
},
|
||||
(buffer: ByteBuffer) => {
|
||||
buffer.get() == 1
|
||||
})
|
||||
testNativeColumnType(BOOLEAN)
|
||||
testNativeColumnType(BYTE)
|
||||
testNativeColumnType(SHORT)
|
||||
testNativeColumnType(INT)
|
||||
testNativeColumnType(LONG)
|
||||
testNativeColumnType(FLOAT)
|
||||
testNativeColumnType(DOUBLE)
|
||||
testNativeColumnType(COMPACT_DECIMAL(15, 10))
|
||||
testNativeColumnType(STRING)
|
||||
|
||||
testNativeColumnType(BYTE)(_.put(_), _.get)
|
||||
testColumnType(NULL)
|
||||
testColumnType(BINARY)
|
||||
testColumnType(LARGE_DECIMAL(20, 10))
|
||||
testColumnType(STRUCT_TYPE)
|
||||
testColumnType(ARRAY_TYPE)
|
||||
testColumnType(MAP_TYPE)
|
||||
|
||||
testNativeColumnType(SHORT)(_.putShort(_), _.getShort)
|
||||
|
||||
testNativeColumnType(INT)(_.putInt(_), _.getInt)
|
||||
|
||||
testNativeColumnType(LONG)(_.putLong(_), _.getLong)
|
||||
|
||||
testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat)
|
||||
|
||||
testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble)
|
||||
|
||||
testNativeColumnType(FIXED_DECIMAL(15, 10))(
|
||||
(buffer: ByteBuffer, decimal: Decimal) => {
|
||||
buffer.putLong(decimal.toUnscaledLong)
|
||||
},
|
||||
(buffer: ByteBuffer) => {
|
||||
Decimal(buffer.getLong(), 15, 10)
|
||||
})
|
||||
|
||||
|
||||
testNativeColumnType(STRING)(
|
||||
(buffer: ByteBuffer, string: UTF8String) => {
|
||||
val bytes = string.getBytes
|
||||
buffer.putInt(bytes.length)
|
||||
buffer.put(bytes)
|
||||
},
|
||||
(buffer: ByteBuffer) => {
|
||||
val length = buffer.getInt()
|
||||
val bytes = new Array[Byte](length)
|
||||
buffer.get(bytes)
|
||||
UTF8String.fromBytes(bytes)
|
||||
})
|
||||
|
||||
testColumnType[Array[Byte]](
|
||||
BINARY,
|
||||
(buffer: ByteBuffer, bytes: Array[Byte]) => {
|
||||
buffer.putInt(bytes.length).put(bytes)
|
||||
},
|
||||
(buffer: ByteBuffer) => {
|
||||
val length = buffer.getInt()
|
||||
val bytes = new Array[Byte](length)
|
||||
buffer.get(bytes, 0, length)
|
||||
bytes
|
||||
})
|
||||
|
||||
test("GENERIC") {
|
||||
val buffer = ByteBuffer.allocate(512)
|
||||
val obj = Map(1 -> "spark", 2 -> "sql")
|
||||
val serializedObj = SparkSqlSerializer.serialize(obj)
|
||||
|
||||
MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
|
||||
buffer.rewind()
|
||||
|
||||
val length = buffer.getInt()
|
||||
assert(length === serializedObj.length)
|
||||
|
||||
assertResult(obj, "Deserialized object didn't equal to the original object") {
|
||||
val bytes = new Array[Byte](length)
|
||||
buffer.get(bytes, 0, length)
|
||||
SparkSqlSerializer.deserialize(bytes)
|
||||
def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = {
|
||||
testColumnType[T#InternalType](columnType)
|
||||
}
|
||||
|
||||
buffer.rewind()
|
||||
buffer.putInt(serializedObj.length).put(serializedObj)
|
||||
|
||||
assertResult(obj, "Deserialized object didn't equal to the original object") {
|
||||
buffer.rewind()
|
||||
SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer))
|
||||
}
|
||||
}
|
||||
|
||||
test("CUSTOM") {
|
||||
val conf = new SparkConf()
|
||||
conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
|
||||
val serializer = new SparkSqlSerializer(conf).newInstance()
|
||||
|
||||
val buffer = ByteBuffer.allocate(512)
|
||||
val obj = CustomClass(Int.MaxValue, Long.MaxValue)
|
||||
val serializedObj = serializer.serialize(obj).array()
|
||||
|
||||
MAP_GENERIC.append(serializer.serialize(obj).array(), buffer)
|
||||
buffer.rewind()
|
||||
|
||||
val length = buffer.getInt
|
||||
assert(length === serializedObj.length)
|
||||
assert(13 == length) // id (1) + int (4) + long (8)
|
||||
|
||||
val genericSerializedObj = SparkSqlSerializer.serialize(obj)
|
||||
assert(length != genericSerializedObj.length)
|
||||
assert(length < genericSerializedObj.length)
|
||||
|
||||
assertResult(obj, "Custom deserialized object didn't equal the original object") {
|
||||
val bytes = new Array[Byte](length)
|
||||
buffer.get(bytes, 0, length)
|
||||
serializer.deserialize(ByteBuffer.wrap(bytes))
|
||||
}
|
||||
|
||||
buffer.rewind()
|
||||
buffer.putInt(serializedObj.length).put(serializedObj)
|
||||
|
||||
assertResult(obj, "Custom deserialized object didn't equal the original object") {
|
||||
buffer.rewind()
|
||||
serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer)))
|
||||
}
|
||||
}
|
||||
|
||||
def testNativeColumnType[T <: AtomicType](
|
||||
columnType: NativeColumnType[T])
|
||||
(putter: (ByteBuffer, T#InternalType) => Unit,
|
||||
getter: (ByteBuffer) => T#InternalType): Unit = {
|
||||
|
||||
testColumnType[T#InternalType](columnType, putter, getter)
|
||||
}
|
||||
|
||||
def testColumnType[JvmType](
|
||||
columnType: ColumnType[JvmType],
|
||||
putter: (ByteBuffer, JvmType) => Unit,
|
||||
getter: (ByteBuffer) => JvmType): Unit = {
|
||||
def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
|
||||
|
||||
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
|
||||
val seq = (0 until 4).map(_ => makeRandomValue(columnType))
|
||||
val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
|
||||
|
||||
test(s"$columnType.extract") {
|
||||
test(s"$columnType append/extract") {
|
||||
buffer.rewind()
|
||||
seq.foreach(putter(buffer, _))
|
||||
seq.foreach(columnType.append(_, buffer))
|
||||
|
||||
buffer.rewind()
|
||||
seq.foreach { expected =>
|
||||
logInfo("buffer = " + buffer + ", expected = " + expected)
|
||||
val extracted = columnType.extract(buffer)
|
||||
assert(
|
||||
expected === extracted,
|
||||
converter(expected) === converter(extracted),
|
||||
"Extracted value didn't equal to the original one. " +
|
||||
hexDump(expected) + " != " + hexDump(extracted) +
|
||||
", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$columnType.append") {
|
||||
buffer.rewind()
|
||||
seq.foreach(columnType.append(_, buffer))
|
||||
|
||||
buffer.rewind()
|
||||
seq.foreach { expected =>
|
||||
assert(
|
||||
expected === getter(buffer),
|
||||
"Extracted value didn't equal to the original one")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def hexDump(value: Any): String = {
|
||||
if (value == null) {
|
||||
""
|
||||
} else {
|
||||
value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
|
||||
}
|
||||
}
|
||||
|
||||
private def dumpBuffer(buff: ByteBuffer): Any = {
|
||||
val sb = new StringBuilder()
|
||||
|
@ -253,33 +140,13 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
|
||||
test("column type for decimal types with different precision") {
|
||||
(1 to 18).foreach { i =>
|
||||
assertResult(FIXED_DECIMAL(i, 0)) {
|
||||
assertResult(COMPACT_DECIMAL(i, 0)) {
|
||||
ColumnType(DecimalType(i, 0))
|
||||
}
|
||||
}
|
||||
|
||||
assertResult(GENERIC(DecimalType(19, 0))) {
|
||||
assertResult(LARGE_DECIMAL(19, 0)) {
|
||||
ColumnType(DecimalType(19, 0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[columnar] final case class CustomClass(a: Int, b: Long)
|
||||
|
||||
private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
|
||||
override def write(kryo: Kryo, output: Output, t: CustomClass) {
|
||||
output.writeInt(t.a)
|
||||
output.writeLong(t.b)
|
||||
}
|
||||
override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
|
||||
val a = input.readInt()
|
||||
val b = input.readLong()
|
||||
CustomClass(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
private[columnar] final class Registrator extends KryoRegistrator {
|
||||
override def registerClasses(kryo: Kryo) {
|
||||
kryo.register(classOf[CustomClass], CustomerSerializer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,8 +21,8 @@ import scala.collection.immutable.HashSet
|
|||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
|
||||
import org.apache.spark.sql.types.{AtomicType, Decimal}
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow}
|
||||
import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, AtomicType, Decimal}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
object ColumnarTestUtils {
|
||||
|
@ -40,6 +40,7 @@ object ColumnarTestUtils {
|
|||
}
|
||||
|
||||
(columnType match {
|
||||
case NULL => null
|
||||
case BOOLEAN => Random.nextBoolean()
|
||||
case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
|
||||
case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
|
||||
|
@ -49,10 +50,15 @@ object ColumnarTestUtils {
|
|||
case DOUBLE => Random.nextDouble()
|
||||
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
|
||||
case BINARY => randomBytes(Random.nextInt(32))
|
||||
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
|
||||
case _ =>
|
||||
// Using a random one-element map instead of an arbitrary object
|
||||
Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
|
||||
case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
|
||||
case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
|
||||
case STRUCT(_) =>
|
||||
new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
|
||||
case ARRAY(_) =>
|
||||
new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
|
||||
case MAP(_) =>
|
||||
ArrayBasedMapData(
|
||||
Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
|
||||
}).asInstanceOf[JvmType]
|
||||
}
|
||||
|
||||
|
|
|
@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
// Create a RDD for the schema
|
||||
val rdd =
|
||||
sparkContext.parallelize((1 to 100), 10).map { i =>
|
||||
sparkContext.parallelize((1 to 10000), 10).map { i =>
|
||||
Row(
|
||||
s"str${i}: test cache.",
|
||||
s"binary${i}: test cache.".getBytes("UTF-8"),
|
||||
|
@ -172,9 +172,9 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
|
|||
BigDecimal(Long.MaxValue.toString + ".12345"),
|
||||
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
|
||||
new Date(i),
|
||||
new Timestamp(i),
|
||||
(1 to i).toSeq,
|
||||
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
|
||||
new Timestamp(i * 1000000L),
|
||||
(i to i + 10).toSeq,
|
||||
(i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
|
||||
Row((i - 0.25).toFloat, Seq(true, false, null)))
|
||||
}
|
||||
sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
|
||||
|
|
|
@ -20,8 +20,9 @@ package org.apache.spark.sql.columnar
|
|||
import java.nio.ByteBuffer
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
|
||||
import org.apache.spark.sql.types.{ArrayType, StringType}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class TestNullableColumnAccessor[JvmType](
|
||||
buffer: ByteBuffer,
|
||||
|
@ -40,8 +41,10 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
|
|||
import org.apache.spark.sql.columnar.ColumnarTestUtils._
|
||||
|
||||
Seq(
|
||||
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
|
||||
STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
|
||||
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
|
||||
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
|
||||
STRUCT(StructType(StructField("a", StringType) :: Nil)),
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
|
||||
.foreach {
|
||||
testNullableColumnAccessor(_)
|
||||
}
|
||||
|
@ -69,11 +72,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
|
|||
|
||||
val accessor = TestNullableColumnAccessor(builder.build(), columnType)
|
||||
val row = new GenericMutableRow(1)
|
||||
val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
|
||||
|
||||
(0 until 4).foreach { _ =>
|
||||
assert(accessor.hasNext)
|
||||
accessor.extractTo(row, 0)
|
||||
assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))
|
||||
assert(converter(row.get(0, columnType.dataType))
|
||||
=== converter(randomRow.get(0, columnType.dataType)))
|
||||
|
||||
assert(accessor.hasNext)
|
||||
accessor.extractTo(row, 0)
|
||||
|
|
|
@ -18,7 +18,8 @@
|
|||
package org.apache.spark.sql.columnar
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
|
||||
|
@ -35,11 +36,13 @@ object TestNullableColumnBuilder {
|
|||
}
|
||||
|
||||
class NullableColumnBuilderSuite extends SparkFunSuite {
|
||||
import ColumnarTestUtils._
|
||||
import org.apache.spark.sql.columnar.ColumnarTestUtils._
|
||||
|
||||
Seq(
|
||||
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
|
||||
STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
|
||||
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
|
||||
STRUCT(StructType(StructField("a", StringType) :: Nil)),
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
|
||||
.foreach {
|
||||
testNullableColumnBuilder(_)
|
||||
}
|
||||
|
@ -74,6 +77,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
|
|||
val columnBuilder = TestNullableColumnBuilder(columnType)
|
||||
val randomRow = makeRandomRow(columnType)
|
||||
val nullRow = makeNullRow(1)
|
||||
val dataType = columnType.dataType
|
||||
val converter = CatalystTypeConverters.createToScalaConverter(dataType)
|
||||
|
||||
(0 until 4).foreach { _ =>
|
||||
columnBuilder.appendFrom(randomRow, 0)
|
||||
|
@ -88,14 +93,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
|
|||
(1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt()))
|
||||
|
||||
// For non-null values
|
||||
val actual = new GenericMutableRow(new Array[Any](1))
|
||||
(0 until 4).foreach { _ =>
|
||||
val actual = if (columnType.isInstanceOf[GENERIC]) {
|
||||
SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]])
|
||||
} else {
|
||||
columnType.extract(buffer)
|
||||
}
|
||||
|
||||
assert(actual === randomRow.get(0, columnType.dataType),
|
||||
columnType.extract(buffer, actual, 0)
|
||||
assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)),
|
||||
"Extracted value didn't equal to the original one")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue