[SPARK-8692] [SQL] re-order the case statements that handle catalyst data types

Use the same order everywhere: boolean, byte, short, int, date, long, timestamp, float, double, string, binary, decimal.

Then we can easily check at a glance whether any data types are missing, and make sure we handle date/timestamp just like int/long.
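
To make the convention concrete, here is a minimal sketch of a match written in this order (a hypothetical defaultSizeOf helper for illustration only, not code from this patch; the sizes mirror the defaultSize expectations in ColumnTypeSuite below):

    import org.apache.spark.sql.types._

    // Hypothetical helper, illustrating the canonical case order.
    // Date shares int's branch and timestamp shares long's, so a
    // forgotten date/timestamp case is immediately visible.
    def defaultSizeOf(dataType: DataType): Int = dataType match {
      case BooleanType => 1
      case ByteType => 1
      case ShortType => 2
      case IntegerType | DateType => 4     // DATE is an int internally
      case LongType | TimestampType => 8   // TIMESTAMP is a long internally
      case FloatType => 4
      case DoubleType => 8
      case StringType => 8
      case BinaryType => 16
      case _: DecimalType => 8
      case _ => 16
    }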

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7073 from cloud-fan/fix-date and squashes the following commits:

463044d [Wenchen Fan] fix style
51cd347 [Wenchen Fan] refactor handling of date and timestamp
Wenchen Fan authored 2015-06-29 11:41:26 -07:00, committed by Cheng Lian
parent ea88b1a507
commit ed413bcc78

15 changed files with 199 additions and 234 deletions


@@ -196,15 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
   def this(dataTypes: Seq[DataType]) =
     this(
       dataTypes.map {
-        case IntegerType => new MutableInt
-        case ByteType => new MutableByte
-        case FloatType => new MutableFloat
-        case ShortType => new MutableShort
-        case DoubleType => new MutableDouble
         case BooleanType => new MutableBoolean
-        case LongType => new MutableLong
-        case DateType => new MutableInt // We use INT for DATE internally
-        case TimestampType => new MutableLong // We use Long for Timestamp internally
+        case ByteType => new MutableByte
+        case ShortType => new MutableShort
+        // We use INT for DATE internally
+        case IntegerType | DateType => new MutableInt
+        // We use Long for Timestamp internally
+        case LongType | TimestampType => new MutableLong
+        case FloatType => new MutableFloat
+        case DoubleType => new MutableDouble
         case _ => new MutableAny
       }.toArray)
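
A quick usage sketch of the merged mapping above (hypothetical values, assuming the Spark 1.5-era internal API): a SpecificMutableRow built from Seq(DateType, TimestampType) holds a MutableInt and a MutableLong, so date/timestamp values are set and read exactly like int/long:

    val row = new SpecificMutableRow(Seq(DateType, TimestampType))
    row.setInt(0, 16250)            // a date: days since epoch, stored as an int
    row.setLong(1, 1435600000000L)  // a timestamp, stored as a long
    assert(row.getInt(0) == 16250 && row.getLong(1) == 1435600000000L)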


@@ -128,14 +128,12 @@ private object UnsafeColumnWriter {
     case BooleanType => BooleanUnsafeColumnWriter
     case ByteType => ByteUnsafeColumnWriter
     case ShortType => ShortUnsafeColumnWriter
-    case IntegerType => IntUnsafeColumnWriter
-    case LongType => LongUnsafeColumnWriter
+    case IntegerType | DateType => IntUnsafeColumnWriter
+    case LongType | TimestampType => LongUnsafeColumnWriter
     case FloatType => FloatUnsafeColumnWriter
     case DoubleType => DoubleUnsafeColumnWriter
     case StringType => StringUnsafeColumnWriter
     case BinaryType => BinaryUnsafeColumnWriter
-    case DateType => IntUnsafeColumnWriter
-    case TimestampType => LongUnsafeColumnWriter
     case t =>
       throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
   }


@@ -120,15 +120,13 @@ class CodeGenContext {
     case BooleanType => JAVA_BOOLEAN
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
-    case IntegerType => JAVA_INT
-    case LongType => JAVA_LONG
+    case IntegerType | DateType => JAVA_INT
+    case LongType | TimestampType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
     case dt: DecimalType => decimalType
     case BinaryType => "byte[]"
     case StringType => stringType
-    case DateType => JAVA_INT
-    case TimestampType => JAVA_LONG
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
     case _ => "Object"


@@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
 private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, BOOLEAN)
 
-private[sql] class IntColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, INT)
+private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, BYTE)
 
 private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, SHORT)
 
+private[sql] class IntColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, INT)
+
 private[sql] class LongColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, LONG)
 
-private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, BYTE)
-
-private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, DOUBLE)
-
 private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, FLOAT)
 
-private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
-  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, DOUBLE)
 
 private[sql] class StringColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, STRING)
 
+private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+  with NullableColumnAccessor
+
+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))
+
+private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+  with NullableColumnAccessor
+
 private[sql] class DateColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, DATE)
 
 private[sql] class TimestampColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, TIMESTAMP)
 
-private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
-  extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
-  with NullableColumnAccessor
-
-private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
-  extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
-  with NullableColumnAccessor
-
 private[sql] object ColumnAccessor {
   def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
     val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
@@ -118,17 +118,17 @@ private[sql] object ColumnAccessor {
     dup.getInt()
 
     dataType match {
-      case IntegerType => new IntColumnAccessor(dup)
-      case LongType => new LongColumnAccessor(dup)
-      case FloatType => new FloatColumnAccessor(dup)
-      case DoubleType => new DoubleColumnAccessor(dup)
       case BooleanType => new BooleanColumnAccessor(dup)
       case ByteType => new ByteColumnAccessor(dup)
       case ShortType => new ShortColumnAccessor(dup)
+      case IntegerType => new IntColumnAccessor(dup)
+      case DateType => new DateColumnAccessor(dup)
+      case LongType => new LongColumnAccessor(dup)
+      case TimestampType => new TimestampColumnAccessor(dup)
+      case FloatType => new FloatColumnAccessor(dup)
+      case DoubleType => new DoubleColumnAccessor(dup)
       case StringType => new StringColumnAccessor(dup)
       case BinaryType => new BinaryColumnAccessor(dup)
-      case DateType => new DateColumnAccessor(dup)
-      case TimestampType => new TimestampColumnAccessor(dup)
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
         new FixedDecimalColumnAccessor(dup, precision, scale)
       case _ => new GenericColumnAccessor(dup)


@@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType](
 private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
 
-private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
+private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
 
 private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
 
+private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
+
 private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
 
-private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
+private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
 
 private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
 
-private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
+private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+
+private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
 
 private[sql] class FixedDecimalColumnBuilder(
     precision: Int,
@@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder(
     new FixedDecimalColumnStats,
     FIXED_DECIMAL(precision, scale))
 
-private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+// TODO (lian) Add support for array, struct and map
+private[sql] class GenericColumnBuilder
+  extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
 
 private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
 
 private[sql] class TimestampColumnBuilder
   extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP)
 
-private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
-
-// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
-  extends ComplexColumnBuilder(new GenericColumnStats, GENERIC)
-
 private[sql] object ColumnBuilder {
   val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024
@@ -151,17 +151,17 @@ private[sql] object ColumnBuilder {
       columnName: String = "",
       useCompression: Boolean = false): ColumnBuilder = {
     val builder: ColumnBuilder = dataType match {
-      case IntegerType => new IntColumnBuilder
-      case LongType => new LongColumnBuilder
-      case FloatType => new FloatColumnBuilder
-      case DoubleType => new DoubleColumnBuilder
       case BooleanType => new BooleanColumnBuilder
       case ByteType => new ByteColumnBuilder
       case ShortType => new ShortColumnBuilder
+      case IntegerType => new IntColumnBuilder
+      case DateType => new DateColumnBuilder
+      case LongType => new LongColumnBuilder
+      case TimestampType => new TimestampColumnBuilder
+      case FloatType => new FloatColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
       case StringType => new StringColumnBuilder
       case BinaryType => new BinaryColumnBuilder
-      case DateType => new DateColumnBuilder
-      case TimestampType => new TimestampColumnBuilder
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
        new FixedDecimalColumnBuilder(precision, scale)
      case _ => new GenericColumnBuilder


@@ -132,6 +132,24 @@ private[sql] class ShortColumnStats extends ColumnStats {
     InternalRow(lower, upper, nullCount, count, sizeInBytes)
 }
 
+private[sql] class IntColumnStats extends ColumnStats {
+  protected var upper = Int.MinValue
+  protected var lower = Int.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getInt(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += INT.defaultSize
+    }
+  }
+
+  override def collectedStatistics: InternalRow =
+    InternalRow(lower, upper, nullCount, count, sizeInBytes)
+}
+
 private[sql] class LongColumnStats extends ColumnStats {
   protected var upper = Long.MinValue
   protected var lower = Long.MaxValue
@@ -150,24 +168,6 @@ private[sql] class LongColumnStats extends ColumnStats {
     InternalRow(lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[sql] class DoubleColumnStats extends ColumnStats {
-  protected var upper = Double.MinValue
-  protected var lower = Double.MaxValue
-
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
-    if (!row.isNullAt(ordinal)) {
-      val value = row.getDouble(ordinal)
-      if (value > upper) upper = value
-      if (value < lower) lower = value
-      sizeInBytes += DOUBLE.defaultSize
-    }
-  }
-
-  override def collectedStatistics: InternalRow =
-    InternalRow(lower, upper, nullCount, count, sizeInBytes)
-}
-
 private[sql] class FloatColumnStats extends ColumnStats {
   protected var upper = Float.MinValue
   protected var lower = Float.MaxValue
@@ -186,35 +186,17 @@ private[sql] class FloatColumnStats extends ColumnStats {
     InternalRow(lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[sql] class FixedDecimalColumnStats extends ColumnStats {
-  protected var upper: Decimal = null
-  protected var lower: Decimal = null
+private[sql] class DoubleColumnStats extends ColumnStats {
+  protected var upper = Double.MinValue
+  protected var lower = Double.MaxValue
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      val value = row(ordinal).asInstanceOf[Decimal]
-      if (upper == null || value.compareTo(upper) > 0) upper = value
-      if (lower == null || value.compareTo(lower) < 0) lower = value
-      sizeInBytes += FIXED_DECIMAL.defaultSize
-    }
-  }
-
-  override def collectedStatistics: InternalRow =
-    InternalRow(lower, upper, nullCount, count, sizeInBytes)
-}
-
-private[sql] class IntColumnStats extends ColumnStats {
-  protected var upper = Int.MinValue
-  protected var lower = Int.MaxValue
-
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    super.gatherStats(row, ordinal)
-    if (!row.isNullAt(ordinal)) {
-      val value = row.getInt(ordinal)
+      val value = row.getDouble(ordinal)
       if (value > upper) upper = value
       if (value < lower) lower = value
-      sizeInBytes += INT.defaultSize
+      sizeInBytes += DOUBLE.defaultSize
     }
   }
@@ -240,10 +222,6 @@ private[sql] class StringColumnStats extends ColumnStats {
     InternalRow(lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[sql] class DateColumnStats extends IntColumnStats
-
-private[sql] class TimestampColumnStats extends LongColumnStats
-
 private[sql] class BinaryColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
@@ -256,6 +234,24 @@ private[sql] class BinaryColumnStats extends ColumnStats {
     InternalRow(null, null, nullCount, count, sizeInBytes)
 }
 
+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Decimal]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      sizeInBytes += FIXED_DECIMAL.defaultSize
+    }
+  }
+
+  override def collectedStatistics: InternalRow =
+    InternalRow(lower, upper, nullCount, count, sizeInBytes)
+}
+
 private[sql] class GenericColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
@@ -267,3 +263,7 @@ private[sql] class GenericColumnStats extends ColumnStats {
   override def collectedStatistics: InternalRow =
     InternalRow(null, null, nullCount, count, sizeInBytes)
 }
+
+private[sql] class DateColumnStats extends IntColumnStats
+
+private[sql] class TimestampColumnStats extends LongColumnStats


@@ -447,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
 private[sql] object ColumnType {
   def apply(dataType: DataType): ColumnType[_, _] = {
     dataType match {
-      case IntegerType => INT
-      case LongType => LONG
-      case FloatType => FLOAT
-      case DoubleType => DOUBLE
       case BooleanType => BOOLEAN
       case ByteType => BYTE
       case ShortType => SHORT
+      case IntegerType => INT
+      case DateType => DATE
+      case LongType => LONG
+      case TimestampType => TIMESTAMP
+      case FloatType => FLOAT
+      case DoubleType => DOUBLE
       case StringType => STRING
       case BinaryType => BINARY
-      case DateType => DATE
-      case TimestampType => TIMESTAMP
       case DecimalType.Fixed(precision, scale) if precision < 19 =>
         FIXED_DECIMAL(precision, scale)
       case _ => GENERIC


@@ -237,7 +237,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeShort(row.getShort(i))
             }
 
-          case IntegerType =>
+          case IntegerType | DateType =>
             if (row.isNullAt(i)) {
               out.writeByte(NULL)
             } else {
@@ -245,7 +245,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeInt(row.getInt(i))
             }
 
-          case LongType =>
+          case LongType | TimestampType =>
             if (row.isNullAt(i)) {
               out.writeByte(NULL)
             } else {
@@ -269,37 +269,6 @@ private[sql] object SparkSqlSerializer2 {
               out.writeDouble(row.getDouble(i))
             }
 
-          case decimal: DecimalType =>
-            if (row.isNullAt(i)) {
-              out.writeByte(NULL)
-            } else {
-              out.writeByte(NOT_NULL)
-              val value = row.apply(i).asInstanceOf[Decimal]
-              val javaBigDecimal = value.toJavaBigDecimal
-              // First, write out the unscaled value.
-              val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
-              out.writeInt(bytes.length)
-              out.write(bytes)
-              // Then, write out the scale.
-              out.writeInt(javaBigDecimal.scale())
-            }
-
-          case DateType =>
-            if (row.isNullAt(i)) {
-              out.writeByte(NULL)
-            } else {
-              out.writeByte(NOT_NULL)
-              out.writeInt(row.getAs[Int](i))
-            }
-
-          case TimestampType =>
-            if (row.isNullAt(i)) {
-              out.writeByte(NULL)
-            } else {
-              out.writeByte(NOT_NULL)
-              out.writeLong(row.getAs[Long](i))
-            }
-
           case StringType =>
             if (row.isNullAt(i)) {
               out.writeByte(NULL)
@@ -319,6 +288,21 @@ private[sql] object SparkSqlSerializer2 {
               out.writeInt(bytes.length)
               out.write(bytes)
             }
+
+          case decimal: DecimalType =>
+            if (row.isNullAt(i)) {
+              out.writeByte(NULL)
+            } else {
+              out.writeByte(NOT_NULL)
+              val value = row.apply(i).asInstanceOf[Decimal]
+              val javaBigDecimal = value.toJavaBigDecimal
+              // First, write out the unscaled value.
+              val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
+              out.writeInt(bytes.length)
+              out.write(bytes)
+              // Then, write out the scale.
+              out.writeInt(javaBigDecimal.scale())
+            }
         }
 
         i += 1
@@ -364,14 +348,14 @@ private[sql] object SparkSqlSerializer2 {
               mutableRow.setShort(i, in.readShort())
             }
 
-          case IntegerType =>
+          case IntegerType | DateType =>
             if (in.readByte() == NULL) {
               mutableRow.setNullAt(i)
             } else {
               mutableRow.setInt(i, in.readInt())
             }
 
-          case LongType =>
+          case LongType | TimestampType =>
             if (in.readByte() == NULL) {
               mutableRow.setNullAt(i)
             } else {
@@ -392,35 +376,6 @@ private[sql] object SparkSqlSerializer2 {
               mutableRow.setDouble(i, in.readDouble())
             }
 
-          case decimal: DecimalType =>
-            if (in.readByte() == NULL) {
-              mutableRow.setNullAt(i)
-            } else {
-              // First, read in the unscaled value.
-              val length = in.readInt()
-              val bytes = new Array[Byte](length)
-              in.readFully(bytes)
-              val unscaledVal = new BigInteger(bytes)
-              // Then, read the scale.
-              val scale = in.readInt()
-              // Finally, create the Decimal object and set it in the row.
-              mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
-            }
-
-          case DateType =>
-            if (in.readByte() == NULL) {
-              mutableRow.setNullAt(i)
-            } else {
-              mutableRow.update(i, in.readInt())
-            }
-
-          case TimestampType =>
-            if (in.readByte() == NULL) {
-              mutableRow.setNullAt(i)
-            } else {
-              mutableRow.update(i, in.readLong())
-            }
-
           case StringType =>
             if (in.readByte() == NULL) {
               mutableRow.setNullAt(i)
@@ -440,6 +395,21 @@ private[sql] object SparkSqlSerializer2 {
               in.readFully(bytes)
               mutableRow.update(i, bytes)
             }
+
+          case decimal: DecimalType =>
+            if (in.readByte() == NULL) {
+              mutableRow.setNullAt(i)
+            } else {
+              // First, read in the unscaled value.
+              val length = in.readInt()
+              val bytes = new Array[Byte](length)
+              in.readFully(bytes)
+              val unscaledVal = new BigInteger(bytes)
+              // Then, read the scale.
+              val scale = in.readInt()
+              // Finally, create the Decimal object and set it in the row.
+              mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
+            }
         }
 
         i += 1


@@ -198,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging {
   private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = {
     if (value != null) {
       schema match {
+        case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+        case ByteType => writer.addInteger(value.asInstanceOf[Byte])
+        case ShortType => writer.addInteger(value.asInstanceOf[Short])
+        case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int])
+        case LongType => writer.addLong(value.asInstanceOf[Long])
+        case TimestampType => writeTimestamp(value.asInstanceOf[Long])
+        case FloatType => writer.addFloat(value.asInstanceOf[Float])
+        case DoubleType => writer.addDouble(value.asInstanceOf[Double])
         case StringType => writer.addBinary(
           Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
         case BinaryType => writer.addBinary(
           Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
-        case IntegerType => writer.addInteger(value.asInstanceOf[Int])
-        case ShortType => writer.addInteger(value.asInstanceOf[Short])
-        case LongType => writer.addLong(value.asInstanceOf[Long])
-        case TimestampType => writeTimestamp(value.asInstanceOf[Long])
-        case ByteType => writer.addInteger(value.asInstanceOf[Byte])
-        case DoubleType => writer.addDouble(value.asInstanceOf[Double])
-        case FloatType => writer.addFloat(value.asInstanceOf[Float])
-        case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
-        case DateType => writer.addInteger(value.asInstanceOf[Int])
         case d: DecimalType =>
           if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
             sys.error(s"Unsupported datatype $d, cannot write to consumer")
@@ -353,19 +352,18 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
       record: InternalRow,
       index: Int): Unit = {
     ctype match {
+      case BooleanType => writer.addBoolean(record.getBoolean(index))
+      case ByteType => writer.addInteger(record.getByte(index))
+      case ShortType => writer.addInteger(record.getShort(index))
+      case IntegerType | DateType => writer.addInteger(record.getInt(index))
+      case LongType => writer.addLong(record.getLong(index))
+      case TimestampType => writeTimestamp(record.getLong(index))
+      case FloatType => writer.addFloat(record.getFloat(index))
+      case DoubleType => writer.addDouble(record.getDouble(index))
       case StringType => writer.addBinary(
         Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes))
       case BinaryType => writer.addBinary(
         Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]]))
-      case IntegerType => writer.addInteger(record.getInt(index))
-      case ShortType => writer.addInteger(record.getShort(index))
-      case LongType => writer.addLong(record.getLong(index))
-      case ByteType => writer.addInteger(record.getByte(index))
-      case DoubleType => writer.addDouble(record.getDouble(index))
-      case FloatType => writer.addFloat(record.getFloat(index))
-      case BooleanType => writer.addBoolean(record.getBoolean(index))
-      case DateType => writer.addInteger(record.getInt(index))
-      case TimestampType => writeTimestamp(record.getLong(index))
       case d: DecimalType =>
         if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
           sys.error(s"Unsupported datatype $d, cannot write to consumer")


@@ -38,8 +38,8 @@ import org.apache.spark.sql.types._
 
 private[parquet] object ParquetTypesConverter extends Logging {
   def isPrimitiveType(ctype: DataType): Boolean = ctype match {
-    case _: NumericType | BooleanType | StringType | BinaryType => true
-    case _: DataType => false
+    case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true
+    case _ => false
   }
 
   /**


@@ -22,19 +22,20 @@ import org.apache.spark.sql.catalyst.expressions.InternalRow
 import org.apache.spark.sql.types._
 
 class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0))
   testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0))
   testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0))
   testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0))
+  testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
   testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0))
+  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
+    InternalRow(Long.MaxValue, Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE,
     InternalRow(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
   testColumnStats(classOf[FixedDecimalColumnStats],
     FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
-  testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0))
-  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP,
-    InternalRow(Long.MaxValue, Long.MinValue, 0))
 
   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],

@@ -36,9 +36,9 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
   test("defaultSize") {
     val checks = Map(
-      INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
-      FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8,
-      BINARY -> 16, GENERIC -> 16)
+      BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4,
+      LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8,
+      STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
       assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -60,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       }
     }
 
-    checkActualSize(INT, Int.MaxValue, 4)
-    checkActualSize(SHORT, Short.MaxValue, 2)
-    checkActualSize(LONG, Long.MaxValue, 8)
-    checkActualSize(BYTE, Byte.MaxValue, 1)
-    checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(FLOAT, Float.MaxValue, 4)
-    checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(BOOLEAN, true, 1)
+    checkActualSize(BYTE, Byte.MaxValue, 1)
+    checkActualSize(SHORT, Short.MaxValue, 2)
+    checkActualSize(INT, Int.MaxValue, 4)
+    checkActualSize(DATE, Int.MaxValue, 4)
+    checkActualSize(LONG, Long.MaxValue, 8)
+    checkActualSize(TIMESTAMP, Long.MaxValue, 8)
+    checkActualSize(FLOAT, Float.MaxValue, 4)
+    checkActualSize(DOUBLE, Double.MaxValue, 8)
     checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length)
-    checkActualSize(DATE, 0, 4)
-    checkActualSize(TIMESTAMP, 0L, 8)
-
-    val binary = Array.fill[Byte](4)(0: Byte)
-    checkActualSize(BINARY, binary, 4 + 4)
+    checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
+    checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
 
     val generic = Map(1 -> "a")
     checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8)
   }
 
-  testNativeColumnType[BooleanType.type](
-    BOOLEAN,
+  testNativeColumnType(BOOLEAN)(
     (buffer: ByteBuffer, v: Boolean) => {
       buffer.put((if (v) 1 else 0).toByte)
     },
@@ -88,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       buffer.get() == 1
     })
 
-  testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
+  testNativeColumnType(BYTE)(_.put(_), _.get)
 
-  testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
+  testNativeColumnType(SHORT)(_.putShort(_), _.getShort)
 
-  testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
+  testNativeColumnType(INT)(_.putInt(_), _.getInt)
 
-  testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
+  testNativeColumnType(DATE)(_.putInt(_), _.getInt)
 
-  testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+  testNativeColumnType(LONG)(_.putLong(_), _.getLong)
 
-  testNativeColumnType[DecimalType](
-    FIXED_DECIMAL(15, 10),
+  testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong)
+
+  testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat)
+
+  testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble)
+
+  testNativeColumnType(FIXED_DECIMAL(15, 10))(
     (buffer: ByteBuffer, decimal: Decimal) => {
       buffer.putLong(decimal.toUnscaledLong)
     },
@@ -107,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       Decimal(buffer.getLong(), 15, 10)
     })
 
-  testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
-
-  testNativeColumnType[StringType.type](
-    STRING,
+  testNativeColumnType(STRING)(
     (buffer: ByteBuffer, string: UTF8String) => {
       val bytes = string.getBytes
       buffer.putInt(bytes.length)
@@ -197,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
   }
 
   def testNativeColumnType[T <: AtomicType](
-      columnType: NativeColumnType[T],
-      putter: (ByteBuffer, T#InternalType) => Unit,
+      columnType: NativeColumnType[T])
+     (putter: (ByteBuffer, T#InternalType) => Unit,
       getter: (ByteBuffer) => T#InternalType): Unit = {
 
     testColumnType[T, T#InternalType](columnType, putter, getter)

@@ -39,18 +39,18 @@ object ColumnarTestUtils {
   }
 
   (columnType match {
+    case BOOLEAN => Random.nextBoolean()
     case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
     case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
     case INT => Random.nextInt()
+    case DATE => Random.nextInt()
     case LONG => Random.nextLong()
+    case TIMESTAMP => Random.nextLong()
     case FLOAT => Random.nextFloat()
     case DOUBLE => Random.nextDouble()
-    case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
     case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
-    case BOOLEAN => Random.nextBoolean()
     case BINARY => randomBytes(Random.nextInt(32))
-    case DATE => Random.nextInt()
-    case TIMESTAMP => Random.nextLong()
+    case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
     case _ =>
       // Using a random one-element map instead of an arbitrary object
       Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))

@@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
   import ColumnarTestUtils._
 
   Seq(
-    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
-    DATE, TIMESTAMP
-  ).foreach {
+    BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
+    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+    .foreach {
     testNullableColumnAccessor(_)
   }


@@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
   import ColumnarTestUtils._
 
   Seq(
-    INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC,
-    DATE, TIMESTAMP
-  ).foreach {
+    BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE,
+    STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC)
+    .foreach {
     testNullableColumnBuilder(_)
   }