diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 6a16d34083..fdd9125613 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -23,7 +23,6 @@ import java.math.BigInteger; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index f6fa021ade..52069598ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -48,6 +48,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte } object ArrayBasedMapData { + def apply(map: Map[Any, Any]): ArrayBasedMapData = { + val array = map.toArray + ArrayBasedMapData(array.map(_._1), array.map(_._2)) + } + def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 2b1d700987..f04099f54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor import org.apache.spark.sql.types._ @@ -61,6 +62,10 @@ private[sql] abstract class BasicColumnAccessor[JvmType]( protected def underlyingBuffer = buffer } +private[sql] class NullColumnAccess(buffer: ByteBuffer) + extends BasicColumnAccessor[Any](buffer, NULL) + with NullableColumnAccessor + private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) @@ -96,11 +101,23 @@ private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType)) -private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType) - extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType)) +private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType) + extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType)) + with NullableColumnAccessor + +private[sql] class StructColumnAccessor(buffer: 
ByteBuffer, dataType: StructType) + extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType)) + with NullableColumnAccessor + +private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType) + extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType)) + with NullableColumnAccessor + +private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) + extends BasicColumnAccessor[MapData](buffer, MAP(dataType)) with NullableColumnAccessor private[sql] object ColumnAccessor { @@ -108,6 +125,7 @@ private[sql] object ColumnAccessor { val buf = buffer.order(ByteOrder.nativeOrder) dataType match { + case NullType => new NullColumnAccess(buf) case BooleanType => new BooleanColumnAccessor(buf) case ByteType => new ByteColumnAccessor(buf) case ShortType => new ShortColumnAccessor(buf) @@ -117,9 +135,15 @@ case DoubleType => new DoubleColumnAccessor(buf) case StringType => new StringColumnAccessor(buf) case BinaryType => new BinaryColumnAccessor(buf) - case DecimalType.Fixed(precision, scale) if precision < 19 => - new FixedDecimalColumnAccessor(buf, precision, scale) - case other => new GenericColumnAccessor(buf, other) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnAccessor(buf, dt) + case dt: DecimalType => new DecimalColumnAccessor(buf, dt) + case struct: StructType => new StructColumnAccessor(buf, struct) + case array: ArrayType => new ArrayColumnAccessor(buf, array) + case map: MapType => new MapColumnAccessor(buf, map) + case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer) + case other => + throw new Exception(s"Unsupported type: $other") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 2e60564f7c..7a7345a7e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -77,6 +77,10 @@ private[sql] class BasicColumnBuilder[JvmType]( } } +private[sql] class NullColumnBuilder + extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL) + with NullableColumnBuilder + private[sql] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, columnType: ColumnType[JvmType]) @@ -109,16 +113,20 @@ private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringCol private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) -private[sql] class FixedDecimalColumnBuilder( - precision: Int, - scale: Int) - extends NativeColumnBuilder( - new FixedDecimalColumnStats(precision, scale), - FIXED_DECIMAL(precision, scale)) +private[sql] class CompactDecimalColumnBuilder(dataType: DecimalType) + extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType)) -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder(dataType: DataType) - extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType)) +private[sql] class DecimalColumnBuilder(dataType: DecimalType) + extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType)) + +private[sql] class StructColumnBuilder(dataType: StructType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType)) + +private[sql] class ArrayColumnBuilder(dataType: ArrayType) + extends ComplexColumnBuilder(new 
ObjectColumnStats(dataType), ARRAY(dataType)) + +private[sql] class MapColumnBuilder(dataType: MapType) + extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType)) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -145,6 +153,7 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case NullType => new NullColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder case ShortType => new ShortColumnBuilder @@ -154,9 +163,16 @@ case DoubleType => new DoubleColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DecimalType.Fixed(precision, scale) if precision < 19 => - new FixedDecimalColumnBuilder(precision, scale) - case other => new GenericColumnBuilder(other) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + new CompactDecimalColumnBuilder(dt) + case dt: DecimalType => new DecimalColumnBuilder(dt) + case struct: StructType => new StructColumnBuilder(struct) + case array: ArrayType => new ArrayColumnBuilder(array) + case map: MapType => new MapColumnBuilder(map) + case udt: UserDefinedType[_] => + return apply(udt.sqlType, initialSize, columnName, useCompression) + case other => + throw new Exception(s"Unsupported type: $other") } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 3b5052b754..ba61003ba4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -235,7 +235,9 @@ private[sql] class BinaryColumnStats extends ColumnStats { new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) } -private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[sql] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { + def this(dt: DecimalType) = this(dt.precision, dt.scale) + protected var upper: Decimal = null protected var lower: Decimal = null @@ -245,7 +247,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize + // TODO: this is not right for DecimalType with precision > 18 + sizeInBytes += 8 } } @@ -253,8 +256,8 @@ private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends C new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) } -private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { - val columnType = GENERIC(dataType) +private[sql] class ObjectColumnStats(dataType: DataType) extends ColumnStats { + val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 3a0cea8750..3563eacb3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.columnar -import java.nio.ByteBuffer +import java.math.{BigDecimal, BigInteger} +import java.nio.{ByteOrder, ByteBuffer} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.UTF8String /** @@ -102,6 +103,16 @@ private[sql] sealed abstract class ColumnType[JvmType] { override def toString: String = getClass.getSimpleName.stripSuffix("$") } +private[sql] object NULL extends ColumnType[Any] { + + override def dataType: DataType = NullType + override def defaultSize: Int = 0 + override def append(v: Any, buffer: ByteBuffer): Unit = {} + override def extract(buffer: ByteBuffer): Any = null + override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Any = null +} + private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, val defaultSize: Int) @@ -339,10 +350,8 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) { override def clone(v: UTF8String): UTF8String = v.clone() } -private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) - extends NativeColumnType( - DecimalType(precision, scale), - FIXED_DECIMAL.defaultSize) { +private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int) + extends NativeColumnType(DecimalType(precision, scale), 8) { override def extract(buffer: ByteBuffer): Decimal = { Decimal(buffer.getLong(), precision, scale) @@ -365,32 +374,39 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } } -private[sql] object FIXED_DECIMAL { - val defaultSize = 8 +private[sql] object COMPACT_DECIMAL { + def apply(dt: DecimalType): COMPACT_DECIMAL = { + COMPACT_DECIMAL(dt.precision, dt.scale) + } } -private[sql] sealed abstract class ByteArrayColumnType(val defaultSize: Int) - extends ColumnType[Array[Byte]] { +private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) + extends ColumnType[JvmType] { + + def serialize(value: JvmType): Array[Byte] + def deserialize(bytes: Array[Byte]): JvmType override def actualSize(row: InternalRow, ordinal: Int): Int = { - getField(row, ordinal).length + 4 + // TODO: grow the buffer in append(), so serialize() will not be called twice + serialize(getField(row, ordinal)).length + 4 } - override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { - buffer.putInt(v.length).put(v, 0, v.length) + override def append(v: JvmType, buffer: ByteBuffer): Unit = { + val bytes = serialize(v) + buffer.putInt(bytes.length).put(bytes, 0, bytes.length) } - override def extract(buffer: ByteBuffer): Array[Byte] = { + override def extract(buffer: ByteBuffer): JvmType = { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) - bytes + deserialize(bytes) } } -private[sql] object BINARY extends ByteArrayColumnType(16) { +private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { - def dataType: DataType = BooleanType + def dataType: DataType = BinaryType override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row.update(ordinal, value) @@ -399,24 +415,164 @@ 
private[sql] object BINARY extends ByteArrayColumnType(16) { override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { row.getBinary(ordinal) } + + def serialize(value: Array[Byte]): Array[Byte] = value + def deserialize(bytes: Array[Byte]): Array[Byte] = bytes } -// Used to process generic objects (all types other than those listed above). Objects should be -// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized -// byte array. -private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row.update(ordinal, SparkSqlSerializer.deserialize[Any](value)) +private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int) + extends ByteArrayColumnType[Decimal](12) { + + override val dataType: DataType = DecimalType(precision, scale) + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) } - override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - SparkSqlSerializer.serialize(row.get(ordinal, dataType)) + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) } + + override def serialize(value: Decimal): Array[Byte] = { + value.toJavaBigDecimal.unscaledValue().toByteArray + } + + override def deserialize(bytes: Array[Byte]): Decimal = { + val javaDecimal = new BigDecimal(new BigInteger(bytes), scale) + Decimal.apply(javaDecimal, precision, scale) + } +} + +private[sql] object LARGE_DECIMAL { + def apply(dt: DecimalType): LARGE_DECIMAL = { + LARGE_DECIMAL(dt.precision, dt.scale) + } +} + +private[sql] case class STRUCT(dataType: StructType) + extends ByteArrayColumnType[InternalRow](20) { + + private val projection: UnsafeProjection = + UnsafeProjection.create(dataType) + private val numOfFields: Int = dataType.fields.size + + override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): InternalRow = { + row.getStruct(ordinal, numOfFields) + } + + override def serialize(value: InternalRow): Array[Byte] = { + val unsafeRow = if (value.isInstanceOf[UnsafeRow]) { + value.asInstanceOf[UnsafeRow] + } else { + projection(value) + } + unsafeRow.getBytes + } + + override def deserialize(bytes: Array[Byte]): InternalRow = { + val unsafeRow = new UnsafeRow + unsafeRow.pointTo(bytes, numOfFields, bytes.length) + unsafeRow + } + + override def clone(v: InternalRow): InternalRow = v.copy() +} + +private[sql] case class ARRAY(dataType: ArrayType) + extends ByteArrayColumnType[ArrayData](16) { + + private lazy val projection = UnsafeProjection.create(Array[DataType](dataType)) + private val mutableRow = new GenericMutableRow(new Array[Any](1)) + + override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): ArrayData = { + row.getArray(ordinal) + } + + override def serialize(value: ArrayData): Array[Byte] = { + val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) { + value.asInstanceOf[UnsafeArrayData] + } else { + mutableRow(0) = value + projection(mutableRow).getArray(0) + } + val outputBuffer = + ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder()) + outputBuffer.putInt(unsafeArray.numElements()) + val underlying = outputBuffer.array() + 
unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4) + underlying + } + + override def deserialize(bytes: Array[Byte]): ArrayData = { + val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) + val numElements = buffer.getInt + val array = new UnsafeArrayData + array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4) + array + } + + override def clone(v: ArrayData): ArrayData = v.copy() +} + +private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) { + + private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType)) + private val mutableRow = new GenericMutableRow(new Array[Any](1)) + + override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): MapData = { + row.getMap(ordinal) + } + + override def serialize(value: MapData): Array[Byte] = { + val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) { + value.asInstanceOf[UnsafeMapData] + } else { + mutableRow(0) = value + projection(mutableRow).getMap(0) + } + + val outputBuffer = + ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder()) + outputBuffer.putInt(unsafeMap.numElements()) + val keyBytes = unsafeMap.keyArray().getSizeInBytes + outputBuffer.putInt(keyBytes) + val underlying = outputBuffer.array() + unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8) + unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes) + underlying + } + + override def deserialize(bytes: Array[Byte]): MapData = { + val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()) + val numElements = buffer.getInt + val keyArraySize = buffer.getInt + val keyArray = new UnsafeArrayData + val valueArray = new UnsafeArrayData + keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize) + valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements, + bytes.length - 8 - keyArraySize) + new UnsafeMapData(keyArray, valueArray) + } + + override def clone(v: MapData): MapData = v.copy() } private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_] = { dataType match { + case NullType => NULL case BooleanType => BOOLEAN case ByteType => BYTE case ShortType => SHORT @@ -426,9 +582,14 @@ private[sql] object ColumnType { case DoubleType => DOUBLE case StringType => STRING case BinaryType => BINARY - case DecimalType.Fixed(precision, scale) if precision < 19 => - FIXED_DECIMAL(precision, scale) - case other => GENERIC(other) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt) + case dt: DecimalType => LARGE_DECIMAL(dt) + case arr: ArrayType => ARRAY(arr) + case map: MapType => MAP(map) + case struct: StructType => STRUCT(struct) + case udt: UserDefinedType[_] => apply(udt.sqlType) + case other => + throw new Exception(s"Unsupported type: $other") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 708fb4cf79..89a664001b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow 
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ @@ -76,11 +75,11 @@ class ColumnStatsSuite extends SparkFunSuite { def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( initialStatistics: GenericInternalRow): Unit = { - val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName - val columnType = FIXED_DECIMAL(15, 10) + val columnStatsName = classOf[DecimalColumnStats].getSimpleName + val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { - val columnStats = new FixedDecimalColumnStats(15, 10) + val columnStats = new DecimalColumnStats(15, 10) columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { case (actual, expected) => assert(actual === expected) } @@ -89,7 +88,7 @@ class ColumnStatsSuite extends SparkFunSuite { test(s"$columnStatsName: non-empty") { import org.apache.spark.sql.columnar.ColumnarTestUtils._ - val columnStats = new FixedDecimalColumnStats(15, 10) + val columnStats = new DecimalColumnStats(15, 10) val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index a4cbe3525e..ceb8ad97bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -19,28 +19,25 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Kryo, Serializer} - -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.serializer.KryoRegistrator +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Logging, SparkFunSuite} class ColumnTypeSuite extends SparkFunSuite with Logging { private val DEFAULT_BUFFER_SIZE = 512 - private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) + private val MAP_TYPE = MAP(MapType(IntegerType, StringType)) + private val ARRAY_TYPE = ARRAY(ArrayType(IntegerType)) + private val STRUCT_TYPE = STRUCT(StructType(StructField("a", StringType) :: Nil)) test("defaultSize") { val checks = Map( - BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, - LONG -> 8, FLOAT -> 4, DOUBLE -> 8, - STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, - MAP_GENERIC -> 16) + NULL-> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8, + FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12, + STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 16, MAP_TYPE -> 32) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -50,18 +47,19 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } test("actualSize") { - def checkActualSize[JvmType]( - columnType: ColumnType[JvmType], - value: JvmType, + def checkActualSize( + columnType: ColumnType[_], + value: Any, expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { val row = new 
GenericMutableRow(1) - columnType.setField(row, 0, value) + row.update(0, CatalystTypeConverters.convertToCatalyst(value)) columnType.actualSize(row, 0) } } + checkActualSize(NULL, null, 0) checkActualSize(BOOLEAN, true, 1) checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) @@ -69,176 +67,65 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(LONG, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(DOUBLE, Double.MaxValue, 8) - checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - - val generic = Map(1 -> "a") - checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) + checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8) + checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) + checkActualSize(ARRAY_TYPE, Array[Any](1), 16) + checkActualSize(MAP_TYPE, Map(1 -> "a"), 25) + checkActualSize(STRUCT_TYPE, Row("hello"), 28) } - testNativeColumnType(BOOLEAN)( - (buffer: ByteBuffer, v: Boolean) => { - buffer.put((if (v) 1 else 0).toByte) - }, - (buffer: ByteBuffer) => { - buffer.get() == 1 - }) + testNativeColumnType(BOOLEAN) + testNativeColumnType(BYTE) + testNativeColumnType(SHORT) + testNativeColumnType(INT) + testNativeColumnType(LONG) + testNativeColumnType(FLOAT) + testNativeColumnType(DOUBLE) + testNativeColumnType(COMPACT_DECIMAL(15, 10)) + testNativeColumnType(STRING) - testNativeColumnType(BYTE)(_.put(_), _.get) + testColumnType(NULL) + testColumnType(BINARY) + testColumnType(LARGE_DECIMAL(20, 10)) + testColumnType(STRUCT_TYPE) + testColumnType(ARRAY_TYPE) + testColumnType(MAP_TYPE) - testNativeColumnType(SHORT)(_.putShort(_), _.getShort) - - testNativeColumnType(INT)(_.putInt(_), _.getInt) - - testNativeColumnType(LONG)(_.putLong(_), _.getLong) - - testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - - testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - - testNativeColumnType(FIXED_DECIMAL(15, 10))( - (buffer: ByteBuffer, decimal: Decimal) => { - buffer.putLong(decimal.toUnscaledLong) - }, - (buffer: ByteBuffer) => { - Decimal(buffer.getLong(), 15, 10) - }) - - - testNativeColumnType(STRING)( - (buffer: ByteBuffer, string: UTF8String) => { - val bytes = string.getBytes - buffer.putInt(bytes.length) - buffer.put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes) - UTF8String.fromBytes(bytes) - }) - - testColumnType[Array[Byte]]( - BINARY, - (buffer: ByteBuffer, bytes: Array[Byte]) => { - buffer.putInt(bytes.length).put(bytes) - }, - (buffer: ByteBuffer) => { - val length = buffer.getInt() - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - bytes - }) - - test("GENERIC") { - val buffer = ByteBuffer.allocate(512) - val obj = Map(1 -> "spark", 2 -> "sql") - val serializedObj = SparkSqlSerializer.serialize(obj) - - MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) - buffer.rewind() - - val length = buffer.getInt() - assert(length === serializedObj.length) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - SparkSqlSerializer.deserialize(bytes) - } - - buffer.rewind() - 
buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Deserialized object didn't equal to the original object") { - buffer.rewind() - SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) - } + def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = { + testColumnType[T#InternalType](columnType) } - test("CUSTOM") { - val conf = new SparkConf() - conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") - val serializer = new SparkSqlSerializer(conf).newInstance() - - val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue, Long.MaxValue) - val serializedObj = serializer.serialize(obj).array() - - MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) - buffer.rewind() - - val length = buffer.getInt - assert(length === serializedObj.length) - assert(13 == length) // id (1) + int (4) + long (8) - - val genericSerializedObj = SparkSqlSerializer.serialize(obj) - assert(length != genericSerializedObj.length) - assert(length < genericSerializedObj.length) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - serializer.deserialize(ByteBuffer.wrap(bytes)) - } - - buffer.rewind() - buffer.putInt(serializedObj.length).put(serializedObj) - - assertResult(obj, "Custom deserialized object didn't equal the original object") { - buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) - } - } - - def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T]) - (putter: (ByteBuffer, T#InternalType) => Unit, - getter: (ByteBuffer) => T#InternalType): Unit = { - - testColumnType[T#InternalType](columnType, putter, getter) - } - - def testColumnType[JvmType]( - columnType: ColumnType[JvmType], - putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType): Unit = { + def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val seq = (0 until 4).map(_ => makeRandomValue(columnType)) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) - test(s"$columnType.extract") { + test(s"$columnType append/extract") { buffer.rewind() - seq.foreach(putter(buffer, _)) + seq.foreach(columnType.append(_, buffer)) buffer.rewind() seq.foreach { expected => logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( - expected === extracted, + converter(expected) === converter(extracted), "Extracted value didn't equal to the original one. 
" + hexDump(expected) + " != " + hexDump(extracted) + ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer])) } } - - test(s"$columnType.append") { - buffer.rewind() - seq.foreach(columnType.append(_, buffer)) - - buffer.rewind() - seq.foreach { expected => - assert( - expected === getter(buffer), - "Extracted value didn't equal to the original one") - } - } } private def hexDump(value: Any): String = { - value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") + if (value == null) { + "" + } else { + value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ") + } } private def dumpBuffer(buff: ByteBuffer): Any = { @@ -253,33 +140,13 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("column type for decimal types with different precision") { (1 to 18).foreach { i => - assertResult(FIXED_DECIMAL(i, 0)) { + assertResult(COMPACT_DECIMAL(i, 0)) { ColumnType(DecimalType(i, 0)) } } - assertResult(GENERIC(DecimalType(19, 0))) { + assertResult(LARGE_DECIMAL(19, 0)) { ColumnType(DecimalType(19, 0)) } } } - -private[columnar] final case class CustomClass(a: Int, b: Long) - -private[columnar] object CustomerSerializer extends Serializer[CustomClass] { - override def write(kryo: Kryo, output: Output, t: CustomClass) { - output.writeInt(t.a) - output.writeLong(t.b) - } - override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { - val a = input.readInt() - val b = input.readLong() - CustomClass(a, b) - } -} - -private[columnar] final class Registrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[CustomClass], CustomerSerializer) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 123a7053c0..964cdb52b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -21,8 +21,8 @@ import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{AtomicType, Decimal} +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} +import org.apache.spark.sql.types.{ArrayBasedMapData, GenericArrayData, AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String object ColumnarTestUtils { @@ -40,6 +40,7 @@ object ColumnarTestUtils { } (columnType match { + case NULL => null case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort @@ -49,10 +50,15 @@ object ColumnarTestUtils { case DOUBLE => Random.nextDouble() case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) case BINARY => randomBytes(Random.nextInt(32)) - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case _ => - // Using a random one-element map instead of an arbitrary object - Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) + case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) + case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale) + case STRUCT(_) => + new 
GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10)))) + case ARRAY(_) => + new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt())) + case MAP(_) => + ArrayBasedMapData( + Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index ea5dd2be33..6265e40a0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create a RDD for the schema val rdd = - sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 10000), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -172,9 +172,9 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { BigDecimal(Long.MaxValue.toString + ".12345"), new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), new Date(i), - new Timestamp(i), - (1 to i).toSeq, - (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + new Timestamp(i * 1000000L), + (i to i + 10).toSeq, + (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a3a23d37d7..78cebbf3cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{ArrayType, StringType} +import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, @@ -40,8 +41,10 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import org.apache.spark.sql.columnar.ColumnarTestUtils._ Seq( - BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnAccessor(_) } @@ -69,11 +72,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite { val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) + val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) + assert(converter(row.get(0, columnType.dataType)) + === converter(randomRow.get(0, 
columnType.dataType))) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 9557eead27..fba08e626d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) @@ -35,11 +36,13 @@ object TestNullableColumnBuilder { } class NullableColumnBuilderSuite extends SparkFunSuite { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ Seq( BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) + STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10), + STRUCT(StructType(StructField("a", StringType) :: Nil)), + ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType))) .foreach { testNullableColumnBuilder(_) } @@ -74,6 +77,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) + val dataType = columnType.dataType + val converter = CatalystTypeConverters.createToScalaConverter(dataType) (0 until 4).foreach { _ => columnBuilder.appendFrom(randomRow, 0) @@ -88,14 +93,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite { (1 to 7 by 2).foreach(assertResult(_, "Wrong null position")(buffer.getInt())) // For non-null values + val actual = new GenericMutableRow(new Array[Any](1)) (0 until 4).foreach { _ => - val actual = if (columnType.isInstanceOf[GENERIC]) { - SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) - } else { - columnType.extract(buffer) - } - - assert(actual === randomRow.get(0, columnType.dataType), + columnType.extract(buffer, actual, 0) + assert(converter(actual.get(0, dataType)) === converter(randomRow.get(0, dataType)), "Extracted value didn't equal to the original one") }
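
Note (not part of the patch): the decimal handling above is split into two encodings. COMPACT_DECIMAL covers precision <= Decimal.MAX_LONG_DIGITS (18) and stores the unscaled value as a fixed 8-byte long, while LARGE_DECIMAL stores the length-prefixed bytes of the unscaled BigInteger, which is why checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5) expects 4 + 1 bytes. The sketch below only mimics the two byte layouts with plain JDK types; DecimalLayoutSketch and its method names are illustrative, not part of this patch or of Spark.

import java.math.{BigDecimal, BigInteger}
import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical sketch of the two decimal layouts used above (not Spark code).
object DecimalLayoutSketch {

  // Compact layout (precision <= 18): the unscaled value always fits in a Long,
  // so every value occupies exactly 8 bytes.
  def writeCompact(d: BigDecimal, buf: ByteBuffer): Unit =
    buf.putLong(d.unscaledValue().longValueExact())

  def readCompact(buf: ByteBuffer, scale: Int): BigDecimal =
    BigDecimal.valueOf(buf.getLong(), scale)

  // Large layout (precision > 18): a 4-byte length prefix followed by the bytes
  // of the unscaled BigInteger, so the size varies per value.
  def writeLarge(d: BigDecimal, buf: ByteBuffer): Unit = {
    val bytes = d.unscaledValue().toByteArray
    buf.putInt(bytes.length).put(bytes)
  }

  def readLarge(buf: ByteBuffer, scale: Int): BigDecimal = {
    val bytes = new Array[Byte](buf.getInt())
    buf.get(bytes)
    new BigDecimal(new BigInteger(bytes), scale)
  }

  def main(args: Array[String]): Unit = {
    val buf = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder())
    writeCompact(new BigDecimal("12345.6789"), buf)               // 9 unscaled digits
    writeLarge(new BigDecimal("123456789012345678901.234"), buf)  // 24 unscaled digits
    buf.flip()
    println(readCompact(buf, 4)) // 12345.6789
    println(readLarge(buf, 3))   // 123456789012345678901.234
  }
}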
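
A second note on MAP.serialize/deserialize above: only the element count and the size of the key block are written, because the value block length can be recovered as bytes.length - 8 - keyArraySize. The stand-alone sketch below shows just that header layout; MapLayoutSketch is a hypothetical name, and the Array[Byte] arguments stand in for the UnsafeArrayData key and value buffers.

import java.nio.{ByteBuffer, ByteOrder}

// Hypothetical sketch of the [numElements][keySize][keys][values] layout (not Spark code).
object MapLayoutSketch {

  def serialize(numElements: Int, keyBytes: Array[Byte], valueBytes: Array[Byte]): Array[Byte] = {
    val out = ByteBuffer
      .allocate(8 + keyBytes.length + valueBytes.length)
      .order(ByteOrder.nativeOrder())
    out.putInt(numElements)     // entry count shared by the key and value blocks
    out.putInt(keyBytes.length) // tells the reader where the value block starts
    out.put(keyBytes).put(valueBytes)
    out.array()
  }

  def deserialize(bytes: Array[Byte]): (Int, Array[Byte], Array[Byte]) = {
    val in = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
    val numElements = in.getInt()
    val keyLen = in.getInt()
    val keys = new Array[Byte](keyLen)
    val values = new Array[Byte](bytes.length - 8 - keyLen) // value length is implicit
    in.get(keys).get(values)
    (numElements, keys, values)
  }

  def main(args: Array[String]): Unit = {
    val (n, keys, values) = deserialize(serialize(2, Array[Byte](1, 2), Array[Byte](3, 4, 5)))
    println((n, keys.toList, values.toList)) // (2, List(1, 2), List(3, 4, 5))
  }
}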