[SPARK-9425] [SQL] support DecimalType in UnsafeRow

This PR adds support for DecimalType in UnsafeRow: decimals with precision <= 18 are settable (stored in place as an unscaled long), while decimals with higher precision are readable but not settable (a sketch of the two layouts follows the commit summary below).

Author: Davies Liu <davies@databricks.com>

Closes #7758 from davies/unsafe_decimal and squashes the following commits:

478b1ba [Davies Liu] address comments
536314c [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal
7c2e77a [Davies Liu] fix JoinedRow
76d6fa4 [Davies Liu] fix tests
99d3151 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal
d49c6ae [Davies Liu] support DecimalType in UnsafeRow
Davies Liu 2015-07-30 17:18:32 -07:00 committed by Reynold Xin
parent e7a0976e99
commit 0b1a464b6e
23 changed files with 237 additions and 125 deletions
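Before the per-file diffs, here is a brief illustration of the two decimal layouts the patch uses: a Decimal whose precision is at most Decimal.MAX_LONG_DIGITS() (18) lives in the row's fixed-width slot as its unscaled long and can be overwritten in place, while a wider Decimal is serialized as the bytes of its unscaled BigInteger into the variable-length region (at most 16 bytes), with the fixed-width slot holding only (cursor << 32) | numBytes. The following script-style Scala sketch of the two round trips is illustrative only, not part of this commit, and assumes catalyst's Decimal class is on the classpath:

import java.math.{BigDecimal => JavaBigDecimal, BigInteger}
import org.apache.spark.sql.types.Decimal

// Compact layout (precision <= 18): the unscaled value fits in a long,
// mirroring UnsafeRow.setDecimal and the long branch of getDecimal.
val compact = Decimal(BigDecimal("12345.6789"), 10, 4)
val unscaled: Long = compact.toUnscaledLong               // what setLong() would store
val compactBack = Decimal(unscaled, 10, 4)                // what getDecimal() rebuilds
assert(compact.compare(compactBack) == 0)

// Wide layout (precision > 18): the unscaled value is written out as BigInteger
// bytes, mirroring UnsafeRowWriters.DecimalWriter and the byte[] branch of
// getDecimal. This layout is read-only; setDecimal throws for such precisions.
val wide = Decimal(BigDecimal("12345678901234567890.123"), 23, 3)
val bytes: Array[Byte] = wide.toJavaBigDecimal.unscaledValue.toByteArray
val javaDecimal = new JavaBigDecimal(new BigInteger(bytes), 3)
val wideBack = Decimal(new scala.math.BigDecimal(javaDecimal), 23, 3)
assert(wide.compare(wideBack) == 0)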

View file

@ -41,7 +41,7 @@ public interface SpecializedGetters {
double getDouble(int ordinal);
Decimal getDecimal(int ordinal);
Decimal getDecimal(int ordinal, int precision, int scale);
UTF8String getUTF8String(int ordinal);

View file

@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions;
import java.util.Iterator;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
/**
* @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
* false otherwise.
*/
public static boolean supportsGroupKeySchema(StructType schema) {
for (StructField field: schema.fields()) {
if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
return false;
}
}
return true;
}
/**
* @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
* schema, false otherwise.
*/
public static boolean supportsAggregationBufferSchema(StructType schema) {
for (StructField field: schema.fields()) {
if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
if (field.dataType() instanceof DecimalType) {
DecimalType dt = (DecimalType) field.dataType();
if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
return false;
}
} else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
return false;
}
}

View file

@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions;
import java.io.IOException;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@ -65,12 +67,7 @@ public final class UnsafeRow extends MutableRow {
*/
public static final Set<DataType> settableFieldTypes;
/**
* Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
*/
public static final Set<DataType> readableFieldTypes;
// TODO: support DecimalType
// DecimalType(precision <= 18) is settable
static {
settableFieldTypes = Collections.unmodifiableSet(
new HashSet<>(
@ -86,16 +83,6 @@ public final class UnsafeRow extends MutableRow {
DateType,
TimestampType
})));
// We support get() on a superset of the types for which we support set():
final Set<DataType> _readableFieldTypes = new HashSet<>(
Arrays.asList(new DataType[]{
StringType,
BinaryType,
CalendarIntervalType
}));
_readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
}
//////////////////////////////////////////////////////////////////////////////
@ -232,6 +219,21 @@ public final class UnsafeRow extends MutableRow {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
assertIndexIsValid(ordinal);
if (value == null) {
setNullAt(ordinal);
} else {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
setLong(ordinal, value.toUnscaledLong());
} else {
// TODO(davies): support update decimal (hold a bounded space even if it's null)
throw new UnsupportedOperationException();
}
}
}
@Override
public Object get(int ordinal) {
throw new UnsupportedOperationException();
@ -256,7 +258,8 @@ public final class UnsafeRow extends MutableRow {
} else if (dataType instanceof DoubleType) {
return getDouble(ordinal);
} else if (dataType instanceof DecimalType) {
return getDecimal(ordinal);
DecimalType dt = (DecimalType) dataType;
return getDecimal(ordinal, dt.precision(), dt.scale());
} else if (dataType instanceof DateType) {
return getInt(ordinal);
} else if (dataType instanceof TimestampType) {
@ -322,6 +325,22 @@ public final class UnsafeRow extends MutableRow {
return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal));
}
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
assertIndexIsValid(ordinal);
if (isNullAt(ordinal)) {
return null;
}
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return Decimal.apply(getLong(ordinal), precision, scale);
} else {
byte[] bytes = getBinary(ordinal);
BigInteger bigInteger = new BigInteger(bytes);
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
}
}
@Override
public UTF8String getUTF8String(int ordinal) {
assertIndexIsValid(ordinal);

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray;
@ -30,6 +31,47 @@ import org.apache.spark.unsafe.types.UTF8String;
*/
public class UnsafeRowWriters {
/** Writer for Decimal with precision under 18. */
public static class CompactDecimalWriter {
public static int getSize(Decimal input) {
return 0;
}
public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
target.setLong(ordinal, input.toUnscaledLong());
return 0;
}
}
/** Writer for Decimal with precision larger than 18. */
public static class DecimalWriter {
public static int getSize(Decimal input) {
// bounded size
return 16;
}
public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
final long offset = target.getBaseOffset() + cursor;
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
final int numBytes = bytes.length;
assert(numBytes <= 16);
// zero-out the bytes
PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
// Write the bytes to the variable length portion.
PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
target.getBaseObject(), offset, numBytes);
// Set the fixed length portion.
target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
return 16;
}
}
/** Writer for UTF8String. */
public static class UTF8StringWriter {

View file

@ -68,7 +68,7 @@ object CatalystTypeConverters {
case StringType => StringConverter
case DateType => DateConverter
case TimestampType => TimestampConverter
case dt: DecimalType => BigDecimalConverter
case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
case ByteType => ByteConverter
case ShortType => ShortConverter
@ -306,7 +306,8 @@ object CatalystTypeConverters {
DateTimeUtils.toJavaTimestamp(row.getLong(column))
}
private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
case d: JavaBigDecimal => Decimal(d)
@ -314,9 +315,11 @@ object CatalystTypeConverters {
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
row.getDecimal(column).toJavaBigDecimal
row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
}
private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
final override def toScala(catalystValue: Any): Any = catalystValue
final override def toCatalystImpl(scalaValue: T): Any = scalaValue

View file

@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
override def getDecimal(ordinal: Int): Decimal =
getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
getAs[Decimal](ordinal, DecimalType(precision, scale))
override def getInterval(ordinal: Int): CalendarInterval =
getAs[CalendarInterval](ordinal, CalendarIntervalType)

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
import org.apache.spark.sql.types.{StructType, DataType}
import org.apache.spark.sql.types.{Decimal, StructType, DataType}
import org.apache.spark.unsafe.types.UTF8String
/**
@ -225,6 +225,11 @@ class JoinedRow extends InternalRow {
override def getFloat(i: Int): Float =
if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
if (i < row1.numFields) row1.getDecimal(i, precision, scale)
else row2.getDecimal(i - row1.numFields, precision, scale)
}
override def getStruct(i: Int, numFields: Int): InternalRow = {
if (i < row1.numFields) {
row1.getStruct(i, numFields)

View file

@ -106,6 +106,7 @@ class CodeGenContext {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})"
case StringType => s"$getter.getUTF8String($ordinal)"
case BinaryType => s"$getter.getBinary($ordinal)"
case CalendarIntervalType => s"$getter.getInterval($ordinal)"
@ -120,10 +121,10 @@ class CodeGenContext {
*/
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
val jt = javaType(dataType)
if (isPrimitiveType(jt)) {
s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
} else {
s"$row.update($ordinal, $value)"
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
case _ => s"$row.update($ordinal, $value)"
}
}

View file

@ -35,6 +35,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName
private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
@ -42,9 +44,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _: CalendarIntervalType => true
case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
case NullType => true
case t: DecimalType => true
case _ => false
}
def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
case BinaryType =>
s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
case CalendarIntervalType =>
s" + (${ev.isNull} ? 0 : 16)"
case _: StructType =>
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
case _ => ""
}
def genFieldWriter(
ctx: CodeGenContext,
fieldType: DataType,
ev: GeneratedExpressionCode,
primitive: String,
index: Int,
cursor: String): String = fieldType match {
case _ if ctx.isPrimitiveType(fieldType) =>
s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
$CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
} else {
$primitive.setNullAt($index);
}
"""
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
s"""
// make sure Decimal object has the same scale as DecimalType
if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
$cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
} else {
$primitive.setNullAt($index);
}
"""
case StringType =>
s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case BinaryType =>
s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case CalendarIntervalType =>
s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case t: StructType =>
s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
}
/**
* Generates the code to create an [[UnsafeRow]] object based on the input expressions.
* @param ctx context for code generation
@ -69,36 +126,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val allExprs = exprs.map(_.code).mkString("\n")
val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
e.dataType match {
case StringType =>
s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
case BinaryType =>
s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
case CalendarIntervalType =>
s" + (${exprs(i).isNull} ? 0 : 16)"
case _: StructType =>
s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))"
case _ => ""
}
val additionalSize = expressions.zipWithIndex.map {
case (e, i) => genAdditionalSize(e.dataType, exprs(i))
}.mkString("")
val writers = expressions.zipWithIndex.map { case (e, i) =>
val update = e.dataType match {
case dt if ctx.isPrimitiveType(dt) =>
s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
case StringType =>
s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case BinaryType =>
s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case CalendarIntervalType =>
s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case t: StructType =>
s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
}
val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor)
s"""if (${exprs(i).isNull}) {
$ret.setNullAt($i);
} else {
@ -168,35 +201,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
dt match {
case StringType =>
s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
case BinaryType =>
s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
case CalendarIntervalType =>
s" + (${ev.isNull} ? 0 : 16)"
case _: StructType =>
s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
case _ => ""
}
genAdditionalSize(dt, ev)
}.mkString("")
val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) =>
val update = dt match {
case _ if ctx.isPrimitiveType(dt) =>
s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}"
case StringType =>
s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case BinaryType =>
s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case CalendarIntervalType =>
s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case t: StructType =>
s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
case NullType => ""
case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: $dt")
}
val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
s"""
if (${exprs(i).isNull}) {
$primitive.setNullAt($i);

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
import org.apache.spark.unsafe.types.UTF8String
/**
@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow {
def setShort(i: Int, value: Short): Unit = { update(i, value) }
def setByte(i: Int, value: Byte): Unit = { update(i, value) }
def setFloat(i: Int, value: Float): Unit = { update(i, value) }
def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
def setString(i: Int, value: String): Unit = {
update(i, UTF8String.fromString(value))
}

View file

@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
* @return true if successful, false if overflow would occur
*/
def changePrecision(precision: Int, scale: Int): Boolean = {
// fast path for UnsafeProjection
if (precision == this.precision && scale == this.scale) {
return true
}
// First, update our longVal if we can, or transfer over to using a BigDecimal
if (decimalVal.eq(null)) {
if (scale < _scale) {
@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
decimalVal = newVal
} else {
// We're still using Longs, but we should check whether we match the new precision
val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
if (longVal <= -p || longVal >= p) {
// Note that we shouldn't have been able to fix this by switching to BigDecimal
return false

View file

@ -43,7 +43,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData {
override def getDouble(ordinal: Int): Double = getAs(ordinal)
override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)

View file

@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
checkEvaluation(cast(123L, DecimalType(3, 1)), null)
// TODO: Fix the following bug and re-enable it.
// checkEvaluation(cast(123L, DecimalType(2, 0)), null)
checkEvaluation(cast(123L, DecimalType(2, 0)), null)
}
test("cast from boolean") {

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Timestamp, Date}
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Calendar

View file

@ -55,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite
}
test("supported schemas") {
assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
assert(
!supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
assert(
!supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
}
test("empty map") {

View file

@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
// We can copy UnsafeRows as long as they don't reference ObjectPools
val unsafeRowCopy = unsafeRow.copy()
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
FloatType,
DoubleType,
StringType,
BinaryType
// DecimalType.Default,
BinaryType,
DecimalType.USER_DEFAULT
// ArrayType(IntegerType)
)
val converter = UnsafeProjection.create(fieldTypes)
@ -150,7 +149,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(createdFromNull.getDouble(7) === 0.0d)
assert(createdFromNull.getUTF8String(8) === null)
assert(createdFromNull.getBinary(9) === null)
// assert(createdFromNull.get(10) === null)
assert(createdFromNull.getDecimal(10, 10, 0) === null)
// assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.setDouble(7, 700)
r.update(8, UTF8String.fromString("hello"))
r.update(9, "world".getBytes)
// r.update(10, Decimal(10))
r.setDecimal(10, Decimal(10), 10)
// r.update(11, Array(11))
r
}
@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
// assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
setToNullAfterCreation.setDouble(7, 700)
// setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
// setToNullAfterCreation.update(9, "world".getBytes)
// setToNullAfterCreation.update(10, Decimal(10))
setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
// setToNullAfterCreation.update(11, Array(11))
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
// assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
// assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
// assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
rowWithNoNullColumns.getDecimal(10, 10, 0))
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}

View file

@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder(
precision: Int,
scale: Int)
extends NativeColumnBuilder(
new FixedDecimalColumnStats,
new FixedDecimalColumnStats(precision, scale),
FIXED_DECIMAL(precision, scale))
// TODO (lian) Add support for array, struct and map

View file

@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats {
InternalRow(null, null, nullCount, count, sizeInBytes)
}
private[sql] class FixedDecimalColumnStats extends ColumnStats {
private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
protected var upper: Decimal = null
protected var lower: Decimal = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal)
val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += FIXED_DECIMAL.defaultSize

View file

@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
}
override def getField(row: InternalRow, ordinal: Int): Decimal = {
row.getDecimal(ordinal)
row.getDecimal(ordinal, precision, scale)
}
override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {

View file

@ -202,7 +202,7 @@ case class GeneratedAggregate(
val schemaSupportsUnsafe: Boolean = {
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
UnsafeProjection.canSupport(groupKeySchema)
}
child.execute().mapPartitions { iter =>

View file

@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
val value = row.getDecimal(i)
val value = row.getDecimal(i, decimal.precision, decimal.scale)
val javaBigDecimal = value.toJavaBigDecimal
// First, write out the unscaled value.
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray

View file

@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
case BinaryType =>
writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
case DecimalType.Fixed(precision, _) =>
writeDecimal(record.getDecimal(index), precision)
case DecimalType.Fixed(precision, scale) =>
writeDecimal(record.getDecimal(index, precision, scale), precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}

View file

@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
InternalRow(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
testColumnStats(classOf[FixedDecimalColumnStats],
FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
testDecimalColumnStats(InternalRow(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite {
}
test(s"$columnStatsName: non-empty") {
import ColumnarTestUtils._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
val columnStats = columnStatsClass.newInstance()
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
}
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) {
val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
val columnType = FIXED_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new FixedDecimalColumnStats(15, 10)
columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
case (actual, expected) => assert(actual === expected)
}
}
test(s"$columnStatsName: non-empty") {
import org.apache.spark.sql.columnar.ColumnarTestUtils._
val columnStats = new FixedDecimalColumnStats(15, 10)
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
assertResult(10, "Wrong null count")(stats.genericGet(2))
assertResult(20, "Wrong row count")(stats.genericGet(3))
assertResult(stats.genericGet(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
}
}
}
}