[SPARK-10917] [SQL] improve performance of complex type in columnar cache

This PR improves the performance of complex types in the columnar cache by using UnsafeProjection instead of KryoSerializer.

A simple benchmark shows that this PR improves the performance of scanning a cached table with complex columns by roughly 15x compared to Spark 1.5.
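
At a high level, each complex value is now projected into an UnsafeRow (or UnsafeArrayData/UnsafeMapData) and its raw bytes are stored in the column buffer; reading back just points an unsafe row at those bytes. Below is a minimal, illustrative sketch of that round trip using the same catalyst APIs the diff relies on — it is not the committed code (which lives in ColumnType.scala further down):

```
// Illustrative sketch only: serialize a struct via UnsafeProjection and read it
// back by pointing a fresh UnsafeRow at the bytes (1.5/1.6-era API, as in the diff).
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._

val schema = StructType(StructField("b", LongType) :: StructField("c", StringType) :: Nil)
val toUnsafe = UnsafeProjection.create(schema)

def serialize(row: InternalRow): Array[Byte] = toUnsafe(row).getBytes

def deserialize(bytes: Array[Byte]): InternalRow = {
  val row = new UnsafeRow                          // no-arg constructor, as used in the diff
  row.pointTo(bytes, schema.length, bytes.length)  // (bytes, numFields, sizeInBytes)
  row
}
```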

Here is the code used for the benchmark:

```
df = sc.range(1<<23).map(lambda i: Row(a=Row(b=i, c=str(i)), d=range(10), e=dict(zip(range(10), [str(i) for i in range(10)])))).toDF()
df.write.parquet("table")
```
```
df = sqlContext.read.parquet("table")
df.cache()
df.count()
t = time.time()
print df.select("*")._jdf.queryExecution().toRdd().count()
print time.time() - t
```

Author: Davies Liu <davies@databricks.com>

Closes #8971 from davies/complex.
Commit 075a0b6582 (parent dd36ec6bc5)
Davies Liu authored 2015-10-07 15:58:07 -07:00; committed by Davies Liu
12 changed files with 350 additions and 264 deletions


@ -23,7 +23,6 @@ import java.math.BigInteger;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;


@ -48,6 +48,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
}
object ArrayBasedMapData {
def apply(map: Map[Any, Any]): ArrayBasedMapData = {
val array = map.toArray
ArrayBasedMapData(array.map(_._1), array.map(_._2))
}
def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = {
new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values))
}


@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar
import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.types._
@ -61,6 +62,10 @@ private[sql] abstract class BasicColumnAccessor[JvmType](
protected def underlyingBuffer = buffer
}
private[sql] class NullColumnAccess(buffer: ByteBuffer)
extends BasicColumnAccessor[Any](buffer, NULL)
with NullableColumnAccessor
private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeColumnType[T])
@ -96,11 +101,23 @@ private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
with NullableColumnAccessor
private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType)
extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType))
private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType))
with NullableColumnAccessor
private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
with NullableColumnAccessor
private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
with NullableColumnAccessor
private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
with NullableColumnAccessor
private[sql] object ColumnAccessor {
@ -108,6 +125,7 @@ private[sql] object ColumnAccessor {
val buf = buffer.order(ByteOrder.nativeOrder)
dataType match {
case NullType => new NullColumnAccess(buf)
case BooleanType => new BooleanColumnAccessor(buf)
case ByteType => new ByteColumnAccessor(buf)
case ShortType => new ShortColumnAccessor(buf)
@ -117,9 +135,15 @@ private[sql] object ColumnAccessor {
case DoubleType => new DoubleColumnAccessor(buf)
case StringType => new StringColumnAccessor(buf)
case BinaryType => new BinaryColumnAccessor(buf)
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnAccessor(buf, precision, scale)
case other => new GenericColumnAccessor(buf, other)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnAccessor(buf, dt)
case dt: DecimalType => new DecimalColumnAccessor(buf, dt)
case struct: StructType => new StructColumnAccessor(buf, struct)
case array: ArrayType => new ArrayColumnAccessor(buf, array)
case map: MapType => new MapColumnAccessor(buf, map)
case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer)
case other =>
throw new Exception(s"not support type: $other")
}
}
}


@ -77,6 +77,10 @@ private[sql] class BasicColumnBuilder[JvmType](
}
}
private[sql] class NullColumnBuilder
extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
with NullableColumnBuilder
private[sql] abstract class ComplexColumnBuilder[JvmType](
columnStats: ColumnStats,
columnType: ColumnType[JvmType])
@ -109,16 +113,20 @@ private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringCol
private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
private[sql] class FixedDecimalColumnBuilder(
precision: Int,
scale: Int)
extends NativeColumnBuilder(
new FixedDecimalColumnStats(precision, scale),
FIXED_DECIMAL(precision, scale))
private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType)
extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
// TODO (lian) Add support for array, struct and map
private[sql] class GenericColumnBuilder(dataType: DataType)
extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType))
private[sql] class DecimalColumnBuilder(dataType: DecimalType)
extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
private[sql] class StructColumnBuilder(dataType: StructType)
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
private[sql] class ArrayColumnBuilder(dataType: ArrayType)
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
private[sql] class MapColumnBuilder(dataType: MapType)
extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
@ -145,6 +153,7 @@ private[sql] object ColumnBuilder {
columnName: String = "",
useCompression: Boolean = false): ColumnBuilder = {
val builder: ColumnBuilder = dataType match {
case NullType => new NullColumnBuilder
case BooleanType => new BooleanColumnBuilder
case ByteType => new ByteColumnBuilder
case ShortType => new ShortColumnBuilder
@ -154,9 +163,16 @@ private[sql] object ColumnBuilder {
case DoubleType => new DoubleColumnBuilder
case StringType => new StringColumnBuilder
case BinaryType => new BinaryColumnBuilder
case DecimalType.Fixed(precision, scale) if precision < 19 =>
new FixedDecimalColumnBuilder(precision, scale)
case other => new GenericColumnBuilder(other)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnBuilder(dt)
case dt: DecimalType => new DecimalColumnBuilder(dt)
case struct: StructType => new StructColumnBuilder(struct)
case array: ArrayType => new ArrayColumnBuilder(array)
case map: MapType => new MapColumnBuilder(map)
case udt: UserDefinedType[_] =>
return apply(udt.sqlType, initialSize, columnName, useCompression)
case other =>
throw new Exception(s"not suppported type: $other")
}
builder.initialize(initialSize, columnName, useCompression)


@ -235,7 +235,9 @@ private[sql] class BinaryColumnStats extends ColumnStats {
new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
}
private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)
protected var upper: Decimal = null
protected var lower: Decimal = null
@ -245,7 +247,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += FIXED_DECIMAL.defaultSize
// TODO: this is not right for DecimalType with precision > 18
sizeInBytes += 8
}
}
@ -253,8 +256,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
}
private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats {
val columnType = GENERIC(dataType)
private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
val columnType = ColumnType(dataType)
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)


@ -17,14 +17,15 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import java.math.{BigDecimal, BigInteger}
import java.nio.{ByteOrder, ByteBuffer}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
/**
@ -102,6 +103,16 @@ private[sql] sealed abstract class ColumnType[JvmType] {
override def toString: String = getClass.getSimpleName.stripSuffix("$")
}
private[sql] object NULL extends ColumnType[Any] {
override def dataType: DataType = NullType
override def defaultSize: Int = 0
override def append(v: Any, buffer: ByteBuffer): Unit = {}
override def extract(buffer: ByteBuffer): Any = null
override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal)
override def getField(row: InternalRow, ordinal: Int): Any = null
}
private[sql] abstract class NativeColumnType[T <: AtomicType](
val dataType: T,
val defaultSize: Int)
@ -339,10 +350,8 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
override def clone(v: UTF8String): UTF8String = v.clone()
}
private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(
DecimalType(precision, scale),
FIXED_DECIMAL.defaultSize) {
private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(DecimalType(precision, scale), 8) {
override def extract(buffer: ByteBuffer): Decimal = {
Decimal(buffer.getLong(), precision, scale)
@ -365,32 +374,39 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
}
private[sql] object FIXED_DECIMAL {
val defaultSize = 8
private[sql] object COMPACT_DECIMAL {
def apply(dt: DecimalType): COMPACT_DECIMAL = {
COMPACT_DECIMAL(dt.precision, dt.scale)
}
}
private[sql] sealed abstract class ByteArrayColumnType(val defaultSize: Int)
extends ColumnType[Array[Byte]] {
private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
extends ColumnType[JvmType] {
def serialize(value: JvmType): Array[Byte]
def deserialize(bytes: Array[Byte]): JvmType
override def actualSize(row: InternalRow, ordinal: Int): Int = {
getField(row, ordinal).length + 4
// TODO: grow the buffer in append(), so serialize() will not be called twice
serialize(getField(row, ordinal)).length + 4
}
override def append(v: Array[Byte], buffer: ByteBuffer): Unit = {
buffer.putInt(v.length).put(v, 0, v.length)
override def append(v: JvmType, buffer: ByteBuffer): Unit = {
val bytes = serialize(v)
buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
}
override def extract(buffer: ByteBuffer): Array[Byte] = {
override def extract(buffer: ByteBuffer): JvmType = {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
bytes
deserialize(bytes)
}
}
private[sql] object BINARY extends ByteArrayColumnType(16) {
private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
def dataType: DataType = BooleanType
def dataType: DataType = BinaryType
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row.update(ordinal, value)
@ -399,24 +415,164 @@ private[sql] object BINARY extends ByteArrayColumnType(16) {
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
row.getBinary(ordinal)
}
def serialize(value: Array[Byte]): Array[Byte] = value
def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
}
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(16) {
override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row.update(ordinal, SparkSqlSerializer.deserialize[Any](value))
private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int)
extends ByteArrayColumnType[Decimal](12) {
override val dataType: DataType = DecimalType(precision, scale)
override def getField(row: InternalRow, ordinal: Int): Decimal = {
row.getDecimal(ordinal, precision, scale)
}
override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
SparkSqlSerializer.serialize(row.get(ordinal, dataType))
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
row.setDecimal(ordinal, value, precision)
}
override def serialize(value: Decimal): Array[Byte] = {
value.toJavaBigDecimal.unscaledValue().toByteArray
}
override def deserialize(bytes: Array[Byte]): Decimal = {
val javaDecimal = new BigDecimal(new BigInteger(bytes), scale)
Decimal.apply(javaDecimal, precision, scale)
}
}
private[sql] object LARGE_DECIMAL {
def apply(dt: DecimalType): LARGE_DECIMAL = {
LARGE_DECIMAL(dt.precision, dt.scale)
}
}
private[sql] case class STRUCT(dataType: StructType)
extends ByteArrayColumnType[InternalRow](20) {
private val projection: UnsafeProjection =
UnsafeProjection.create(dataType)
private val numOfFields: Int = dataType.fields.size
override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = {
row.update(ordinal, value)
}
override def getField(row: InternalRow, ordinal: Int): InternalRow = {
row.getStruct(ordinal, numOfFields)
}
override def serialize(value: InternalRow): Array[Byte] = {
val unsafeRow = if (value.isInstanceOf[UnsafeRow]) {
value.asInstanceOf[UnsafeRow]
} else {
projection(value)
}
unsafeRow.getBytes
}
override def deserialize(bytes: Array[Byte]): InternalRow = {
val unsafeRow = new UnsafeRow
unsafeRow.pointTo(bytes, numOfFields, bytes.length)
unsafeRow
}
override def clone(v: InternalRow): InternalRow = v.copy()
}
private[sql] case class ARRAY(dataType: ArrayType)
extends ByteArrayColumnType[ArrayData](16) {
private lazy val projection = UnsafeProjection.create(Array[DataType](dataType))
private val mutableRow = new GenericMutableRow(new Array[Any](1))
override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = {
row.update(ordinal, value)
}
override def getField(row: InternalRow, ordinal: Int): ArrayData = {
row.getArray(ordinal)
}
override def serialize(value: ArrayData): Array[Byte] = {
val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) {
value.asInstanceOf[UnsafeArrayData]
} else {
mutableRow(0) = value
projection(mutableRow).getArray(0)
}
val outputBuffer =
ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder())
outputBuffer.putInt(unsafeArray.numElements())
val underlying = outputBuffer.array()
unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4)
underlying
}
override def deserialize(bytes: Array[Byte]): ArrayData = {
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
val numElements = buffer.getInt
val array = new UnsafeArrayData
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4)
array
}
override def clone(v: ArrayData): ArrayData = v.copy()
}
private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) {
private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType))
private val mutableRow = new GenericMutableRow(new Array[Any](1))
override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = {
row.update(ordinal, value)
}
override def getField(row: InternalRow, ordinal: Int): MapData = {
row.getMap(ordinal)
}
override def serialize(value: MapData): Array[Byte] = {
val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) {
value.asInstanceOf[UnsafeMapData]
} else {
mutableRow(0) = value
projection(mutableRow).getMap(0)
}
val outputBuffer =
ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder())
outputBuffer.putInt(unsafeMap.numElements())
val keyBytes = unsafeMap.keyArray().getSizeInBytes
outputBuffer.putInt(keyBytes)
val underlying = outputBuffer.array()
unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8)
unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes)
underlying
}
override def deserialize(bytes: Array[Byte]): MapData = {
val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
val numElements = buffer.getInt
val keyArraySize = buffer.getInt
val keyArray = new UnsafeArrayData
val valueArray = new UnsafeArrayData
keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize)
valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements,
bytes.length - 8 - keyArraySize)
new UnsafeMapData(keyArray, valueArray)
}
override def clone(v: MapData): MapData = v.copy()
}
private[sql] object ColumnType {
def apply(dataType: DataType): ColumnType[_] = {
dataType match {
case NullType => NULL
case BooleanType => BOOLEAN
case ByteType => BYTE
case ShortType => SHORT
@ -426,9 +582,14 @@ private[sql] object ColumnType {
case DoubleType => DOUBLE
case StringType => STRING
case BinaryType => BINARY
case DecimalType.Fixed(precision, scale) if precision < 19 =>
FIXED_DECIMAL(precision, scale)
case other => GENERIC(other)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
case dt: DecimalType => LARGE_DECIMAL(dt)
case arr: ArrayType => ARRAY(arr)
case map: MapType => MAP(map)
case struct: StructType => STRUCT(struct)
case udt: UserDefinedType[_] => apply(udt.sqlType)
case other =>
throw new Exception(s"Unsupported type: $other")
}
}
}
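
For reference, a hedged usage sketch of the new dispatch, modeled on ColumnTypeSuite further down. These classes are private[sql], so this only compiles inside the columnar package, and the object name is made up for illustration:

```
// Sketch, not committed code: ColumnType(dataType) now picks an
// UnsafeProjection-backed column type for complex types instead of the
// Kryo-backed GENERIC, and append/extract round-trips a value through a buffer.
package org.apache.spark.sql.columnar

import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object ColumnTypeRoundTripExample {  // hypothetical name, for illustration only
  def roundTrip(): MapData = {
    val mapColumnType = ColumnType(MapType(IntegerType, StringType))
      .asInstanceOf[ColumnType[MapData]]
    val value = ArrayBasedMapData(Map(1 -> UTF8String.fromString("a")))

    val buffer = ByteBuffer.allocate(128).order(ByteOrder.nativeOrder())
    mapColumnType.append(value, buffer)   // writes length prefix + unsafe bytes
    buffer.rewind()
    mapColumnType.extract(buffer)         // returns an UnsafeMapData over those bytes
  }
}
```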


@ -18,7 +18,6 @@
package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
@ -76,11 +75,11 @@ class ColumnStatsSuite extends SparkFunSuite {
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
initialStatistics: GenericInternalRow): Unit = {
val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
val columnType = FIXED_DECIMAL(15, 10)
val columnStatsName = classOf[DecimalColumnStats].getSimpleName
val columnType = COMPACT_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new FixedDecimalColumnStats(15, 10)
val columnStats = new DecimalColumnStats(15, 10)
columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
case (actual, expected) => assert(actual === expected)
}
@ -89,7 +88,7 @@ class ColumnStatsSuite extends SparkFunSuite {
test(s"$columnStatsName: non-empty") {
import org.apache.spark.sql.columnar.ColumnarTestUtils._
val columnStats = new FixedDecimalColumnStats(15, 10)
val columnStats = new DecimalColumnStats(15, 10)
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))


@ -19,28 +19,25 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoRegistrator
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.{Logging, SparkFunSuite}
class ColumnTypeSuite extends SparkFunSuite with Logging {
private val DEFAULT_BUFFER_SIZE = 512
private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType))
private val MAP_TYPE = MAP(MapType(IntegerType, StringType))
private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType))
private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil))
test("defaultSize") {
val checks = Map(
BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4,
LONG -> 8, FLOAT -> 4, DOUBLE -> 8,
STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8,
MAP_GENERIC -> 16)
NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32)
checks.foreach { case (columnType, expectedSize) =>
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@ -50,18 +47,19 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
}
test("actualSize") {
def checkActualSize[JvmType](
columnType: ColumnType[JvmType],
value: JvmType,
def checkActualSize(
columnType: ColumnType[_],
value: Any,
expected: Int): Unit = {
assertResult(expected, s"Wrong actualSize for $columnType") {
val row = new GenericMutableRow(1)
columnType.setField(row, 0, value)
row.update(0, CatalystTypeConverters.convertToCatalyst(value))
columnType.actualSize(row, 0)
}
}
checkActualSize(NULL, null, 0)
checkActualSize(BOOLEAN, true, 1)
checkActualSize(BYTE, Byte.MaxValue, 1)
checkActualSize(SHORT, Short.MaxValue, 2)
@ -69,177 +67,66 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(DOUBLE, Double.MaxValue, 8)
checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
val generic = Map(1 -> "a")
checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
}
testNativeColumnType(BOOLEAN)(
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
(buffer: ByteBuffer) => {
buffer.get() == 1
})
testNativeColumnType(BOOLEAN)
testNativeColumnType(BYTE)
testNativeColumnType(SHORT)
testNativeColumnType(INT)
testNativeColumnType(LONG)
testNativeColumnType(FLOAT)
testNativeColumnType(DOUBLE)
testNativeColumnType(COMPACT_DECIMAL(15, 10))
testNativeColumnType(STRING)
testNativeColumnType(BYTE)(_.put(_), _.get)
testColumnType(NULL)
testColumnType(BINARY)
testColumnType(LARGE_DECIMAL(20, 10))
testColumnType(STRUCT_TYPE)
testColumnType(ARRAY_TYPE)
testColumnType(MAP_TYPE)
testNativeColumnType(SHORT)(_.putShort(_), _.getShort)
testNativeColumnType(INT)(_.putInt(_), _.getInt)
testNativeColumnType(LONG)(_.putLong(_), _.getLong)
testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat)
testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble)
testNativeColumnType(FIXED_DECIMAL(15, 10))(
(buffer: ByteBuffer, decimal: Decimal) => {
buffer.putLong(decimal.toUnscaledLong)
},
(buffer: ByteBuffer) => {
Decimal(buffer.getLong(), 15, 10)
})
testNativeColumnType(STRING)(
(buffer: ByteBuffer, string: UTF8String) => {
val bytes = string.getBytes
buffer.putInt(bytes.length)
buffer.put(bytes)
},
(buffer: ByteBuffer) => {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes)
UTF8String.fromBytes(bytes)
})
testColumnType[Array[Byte]](
BINARY,
(buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
},
(buffer: ByteBuffer) => {
val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
bytes
})
test("GENERIC") {
val buffer = ByteBuffer.allocate(512)
val obj = Map(1 -> "spark", 2 -> "sql")
val serializedObj = SparkSqlSerializer.serialize(obj)
MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer)
buffer.rewind()
val length = buffer.getInt()
assert(length === serializedObj.length)
assertResult(obj, "Deserialized object didn't equal to the original object") {
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
SparkSqlSerializer.deserialize(bytes)
def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = {
testColumnType[T#InternalType](columnType)
}
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
assertResult(obj, "Deserialized object didn't equal to the original object") {
buffer.rewind()
SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer))
}
}
test("CUSTOM") {
val conf = new SparkConf()
conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator")
val serializer = new SparkSqlSerializer(conf).newInstance()
val buffer = ByteBuffer.allocate(512)
val obj = CustomClass(Int.MaxValue, Long.MaxValue)
val serializedObj = serializer.serialize(obj).array()
MAP_GENERIC.append(serializer.serialize(obj).array(), buffer)
buffer.rewind()
val length = buffer.getInt
assert(length === serializedObj.length)
assert(13 == length) // id (1) + int (4) + long (8)
val genericSerializedObj = SparkSqlSerializer.serialize(obj)
assert(length != genericSerializedObj.length)
assert(length < genericSerializedObj.length)
assertResult(obj, "Custom deserialized object didn't equal the original object") {
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
serializer.deserialize(ByteBuffer.wrap(bytes))
}
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
assertResult(obj, "Custom deserialized object didn't equal the original object") {
buffer.rewind()
serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer)))
}
}
def testNativeColumnType[T <: AtomicType](
columnType: NativeColumnType[T])
(putter: (ByteBuffer, T#InternalType) => Unit,
getter: (ByteBuffer) => T#InternalType): Unit = {
testColumnType[T#InternalType](columnType, putter, getter)
}
def testColumnType[JvmType](
columnType: ColumnType[JvmType],
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType): Unit = {
def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
val seq = (0 until 4).map(_ => makeRandomValue(columnType))
val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
test(s"$columnType.extract") {
test(s"$columnType append/extract") {
buffer.rewind()
seq.foreach(putter(buffer, _))
seq.foreach(columnType.append(_, buffer))
buffer.rewind()
seq.foreach { expected =>
logInfo("buffer = " + buffer + ", expected = " + expected)
val extracted = columnType.extract(buffer)
assert(
expected === extracted,
converter(expected) === converter(extracted),
"Extracted value didn't equal to the original one. " +
hexDump(expected) + " != " + hexDump(extracted) +
", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
}
}
test(s"$columnType.append") {
buffer.rewind()
seq.foreach(columnType.append(_, buffer))
buffer.rewind()
seq.foreach { expected =>
assert(
expected === getter(buffer),
"Extracted value didn't equal to the original one")
}
}
}
private def hexDump(value: Any): String = {
if (value == null) {
""
} else {
value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
}
}
private def dumpBuffer(buff: ByteBuffer): Any = {
val sb = new StringBuilder()
@ -253,33 +140,13 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
test("column type for decimal types with different precision") {
(1 to 18).foreach { i =>
assertResult(FIXED_DECIMAL(i, 0)) {
assertResult(COMPACT_DECIMAL(i, 0)) {
ColumnType(DecimalType(i, 0))
}
}
assertResult(GENERIC(DecimalType(19, 0))) {
assertResult(LARGE_DECIMAL(19, 0)) {
ColumnType(DecimalType(19, 0))
}
}
}
private[columnar] final case class CustomClass(a: Int, b: Long)
private[columnar] object CustomerSerializer extends Serializer[CustomClass] {
override def write(kryo: Kryo, output: Output, t: CustomClass) {
output.writeInt(t.a)
output.writeLong(t.b)
}
override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = {
val a = input.readInt()
val b = input.readLong()
CustomClass(a, b)
}
}
private[columnar] final class Registrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[CustomClass], CustomerSerializer)
}
}


@ -21,8 +21,8 @@ import scala.collection.immutable.HashSet
import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{AtomicType, Decimal}
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow}
import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, AtomicType, Decimal}
import org.apache.spark.unsafe.types.UTF8String
object ColumnarTestUtils {
@ -40,6 +40,7 @@ object ColumnarTestUtils {
}
(columnType match {
case NULL => null
case BOOLEAN => Random.nextBoolean()
case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
@ -49,10 +50,15 @@ object ColumnarTestUtils {
case DOUBLE => Random.nextDouble()
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BINARY => randomBytes(Random.nextInt(32))
case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
case _ =>
// Using a random one-element map instead of an arbitrary object
Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
case STRUCT(_) =>
new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
case ARRAY(_) =>
new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
case MAP(_) =>
ArrayBasedMapData(
Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
}).asInstanceOf[JvmType]
}


@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
// Create a RDD for the schema
val rdd =
sparkContext.parallelize((1 to 100), 10).map { i =>
sparkContext.parallelize((1 to 10000), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
@ -172,9 +172,9 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
new Date(i),
new Timestamp(i),
(1 to i).toSeq,
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
new Timestamp(i * 1000000L),
(i to i + 10).toSeq,
(i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")


@ -20,8 +20,9 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{ArrayType, StringType}
import org.apache.spark.sql.types._
class TestNullableColumnAccessor[JvmType](
buffer: ByteBuffer,
@ -40,8 +41,10 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
import org.apache.spark.sql.columnar.ColumnarTestUtils._
Seq(
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(StructType(StructField("a", StringType) :: Nil)),
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
.foreach {
testNullableColumnAccessor(_)
}
@ -69,11 +72,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
val accessor = TestNullableColumnAccessor(builder.build(), columnType)
val row = new GenericMutableRow(1)
val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
(0 until 4).foreach { _ =>
assert(accessor.hasNext)
accessor.extractTo(row, 0)
assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType))
assert(converter(row.get(0, columnType.dataType))
=== converter(randomRow.get(0, columnType.dataType)))
assert(accessor.hasNext)
accessor.extractTo(row, 0)


@ -18,7 +18,8 @@
package org.apache.spark.sql.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@ -35,11 +36,13 @@ object TestNullableColumnBuilder {
}
class NullableColumnBuilderSuite extends SparkFunSuite {
import ColumnarTestUtils._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
Seq(
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType)))
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(StructType(StructField("a", StringType) :: Nil)),
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
.foreach {
testNullableColumnBuilder(_)
}
@ -74,6 +77,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
val columnBuilder = TestNullableColumnBuilder(columnType)
val randomRow = makeRandomRow(columnType)
val nullRow = makeNullRow(1)
val dataType = columnType.dataType
val converter = CatalystTypeConverters.createToScalaConverter(dataType)
(0 until 4).foreach { _ =>
columnBuilder.appendFrom(randomRow, 0)
@ -88,14 +93,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
(1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt()))
// For non-null values
val actual = new GenericMutableRow(new Array[Any](1))
(0 until 4).foreach { _ =>
val actual = if (columnType.isInstanceOf[GENERIC]) {
SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]])
} else {
columnType.extract(buffer)
}
assert(actual === randomRow.get(0, columnType.dataType),
columnType.extract(buffer, actual, 0)
assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)),
"Extracted value didn't equal to the original one")
}