[SPARK-30047][SQL] Support interval types in UnsafeRow

### What changes were proposed in this pull request?

Support `CalendarIntervalType` as a mutable field type in `UnsafeRow`. This lets aggregates on interval values switch from sort-based to hash-based aggregation and use `org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch` for better performance.
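
At a high level, the writer always reserves a fixed 16-byte payload (months, days, microseconds) in the variable-length region and keeps `(offset << 32) | 16` in the fixed-length slot, so the value can later be updated in place. Below is a minimal sketch of that behavior (illustration only, not part of the patch; it only reuses the `UnsafeRowWriter`/`UnsafeRow.setInterval` APIs touched by this change):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.unsafe.types.CalendarInterval

object IntervalInUnsafeRowSketch {
  def main(args: Array[String]): Unit = {
    val writer = new UnsafeRowWriter(1)
    writer.resetRowWriter()
    writer.write(0, new CalendarInterval(3, 1, 1000L))
    val row: UnsafeRow = writer.getRow

    // The fixed-length slot holds (payloadOffset << 32) | 16: the
    // months/days/microseconds payload always occupies 16 bytes.
    val offsetAndSize = row.getLong(0)
    assert((offsetAndSize & 0xffffffffL) == 16L)

    // Because the offset is preserved, the field can be updated in place,
    // which is what hash-based aggregation needs to mutate interval buffers.
    row.setInterval(0, new CalendarInterval(1, 2, 3L))
    assert(row.getInterval(0) == new CalendarInterval(1, 2, 3L))
    assert((row.getLong(0) >>> 32) == (offsetAndSize >>> 32))
  }
}
```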

### Why are the changes needed?

Improve the performance of aggregates on interval values.
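
For example, an interval average like the one in the new `DataFrameAggregateSuite` test can now be planned with `HashAggregateExec` instead of falling back to sort-based aggregation. A hedged sketch, assuming a local `SparkSession` (e.g. in `spark-shell`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.functions.avg
import org.apache.spark.sql.types.CalendarIntervalType

val spark = SparkSession.builder().master("local[*]").appName("interval-hash-agg").getOrCreate()
import spark.implicits._

// Cast strings to intervals and average them per group; with this patch the
// interval aggregation buffer fits in UnsafeRow, so the planner can use a
// hash-based aggregate.
val df = Seq((1, "1 day"), (2, "2 day"), (3, "3 day")).toDF("a", "b")
val agg = df.groupBy($"a").agg(avg($"b".cast(CalendarIntervalType)))

// The physical plan is expected to contain a HashAggregateExec node.
assert(agg.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
agg.show()
```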

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added new unit tests; also covered by existing tests.

Closes #26680 from yaooqinn/SPARK-30047.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
13 changed files with 128 additions and 39 deletions


@@ -103,7 +103,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
}
public static boolean isMutable(DataType dt) {
return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
return mutableFieldTypes.contains(dt) || dt instanceof DecimalType ||
dt instanceof CalendarIntervalType;
}
//////////////////////////////////////////////////////////////////////////////
@@ -297,6 +298,26 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
}
}
@Override
public void setInterval(int ordinal, CalendarInterval value) {
assertIndexIsValid(ordinal);
long cursor = getLong(ordinal) >>> 32;
assert cursor > 0 : "invalid cursor " + cursor;
if (value == null) {
setNullAt(ordinal);
// zero-out the bytes
Platform.putLong(baseObject, baseOffset + cursor, 0L);
Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
// keep the offset for future update
Platform.putLong(baseObject, getFieldOffset(ordinal), (cursor << 32) | 16L);
} else {
Platform.putInt(baseObject, baseOffset + cursor, value.months);
Platform.putInt(baseObject, baseOffset + cursor + 4, value.days);
Platform.putLong(baseObject, baseOffset + cursor + 8, value.microseconds);
setLong(ordinal, (cursor << 32) | 16L);
}
}
@Override
public Object get(int ordinal, DataType dataType) {
return SpecializedGettersReader.read(this, ordinal, dataType, true, true);


@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -134,13 +135,16 @@ public abstract class UnsafeWriter {
// grow the global buffer before writing data.
grow(16);
// Write the months, days and microseconds fields of Interval to the variable length portion.
Platform.putInt(getBuffer(), cursor(), input.months);
Platform.putInt(getBuffer(), cursor() + 4, input.days);
Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
if (input == null) {
BitSetMethods.set(getBuffer(), startingOffset, ordinal);
} else {
// Write the months, days and microseconds fields of interval to the variable length portion.
Platform.putInt(getBuffer(), cursor(), input.months);
Platform.putInt(getBuffer(), cursor() + 4, input.days);
Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
}
// we need to reserve the space so that we can update it later.
setOffsetAndSize(ordinal, 16);
// move the cursor forward.
increaseCursor(16);
}


@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
@@ -58,6 +58,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
*/
def setDecimal(i: Int, value: Decimal, precision: Int): Unit = update(i, value)
def setInterval(i: Int, value: CalendarInterval): Unit = update(i, value)
/**
* Make a copy of the current [[InternalRow]] object.
*/


@@ -605,6 +605,7 @@ class CodegenContext extends Logging {
s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case CalendarIntervalType => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
@@ -1579,6 +1580,7 @@ object CodeGenerator extends Logging {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case CalendarIntervalType => s"$row.setInterval($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
// The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
@@ -1602,8 +1604,10 @@
nullable: Boolean,
isVectorized: Boolean = false): String = {
if (nullable) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
// Can't call setNullAt on DecimalType/CalendarIntervalType, because we need to keep the
// offset
if (!isVectorized && (dataType.isInstanceOf[DecimalType] ||
dataType.isInstanceOf[CalendarIntervalType])) {
s"""
|if (!${ev.isNull}) {
| ${setColumn(row, dataType, ordinal, ev.value)};
@@ -1634,6 +1638,7 @@
case _ if isPrimitiveType(jt) =>
s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
case CalendarIntervalType => s"$vector.putInterval($rowId, $value);"
case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")


@@ -111,6 +111,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
// Can't call setNullAt() for DecimalType with precision larger than 18.
s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
case CalendarIntervalType => s"$rowWriter.write($index, (CalendarInterval) null);"
case _ => s"$rowWriter.setNullAt($index);"
}


@@ -124,6 +124,36 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
(Timestamp.valueOf("2015-06-22 08:10:25"))
}
testBothCodegenAndInterpreted(
"basic conversion with primitive, string and interval types") {
val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(LongType, StringType, CalendarIntervalType)
val converter = factory.create(fieldTypes)
val row = new SpecificInternalRow(fieldTypes)
row.setLong(0, 0)
row.update(1, UTF8String.fromString("Hello"))
val interval1 = new CalendarInterval(3, 1, 1000L)
row.update(2, interval1)
val unsafeRow: UnsafeRow = converter.apply(row)
assert(unsafeRow.getSizeInBytes ===
8 + 8 * 3 + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length) + 16)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
assert(unsafeRow.getInterval(2) === interval1)
val interval2 = new CalendarInterval(1, 2, 3L)
unsafeRow.setInterval(2, interval2)
assert(unsafeRow.getInterval(2) === interval2)
val offset = unsafeRow.getLong(2) >>> 32
unsafeRow.setInterval(2, null)
assert(unsafeRow.getInterval(2) === null)
assert(unsafeRow.getLong(2) >>> 32 === offset)
}
testBothCodegenAndInterpreted("null handling") {
val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(


@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.CalendarInterval
class UnsafeRowWriterSuite extends SparkFunSuite {
@@ -49,4 +50,15 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
// The two rows should be equal
assert(res1 == res2)
}
test("write and get calendar intervals through UnsafeRowWriter") {
val rowWriter = new UnsafeRowWriter(2)
rowWriter.resetRowWriter()
rowWriter.write(0, null.asInstanceOf[CalendarInterval])
assert(rowWriter.getRow.isNullAt(0))
assert(rowWriter.getRow.getInterval(0) === null)
val interval = new CalendarInterval(0, 1, 0)
rowWriter.write(1, interval)
assert(rowWriter.getRow.getInterval(1) === interval)
}
}


@@ -212,6 +212,8 @@ public final class MutableColumnarRow extends InternalRow {
DecimalType t = (DecimalType) dt;
Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
setDecimal(ordinal, d, t.precision());
} else if (dt instanceof CalendarIntervalType) {
setInterval(ordinal, (CalendarInterval) value);
} else {
throw new UnsupportedOperationException("Datatype not supported " + dt);
}
@@ -270,4 +272,10 @@ public final class MutableColumnarRow extends InternalRow {
columns[ordinal].putNotNull(rowId);
columns[ordinal].putDecimal(rowId, value, precision);
}
@Override
public void setInterval(int ordinal, CalendarInterval value) {
columns[ordinal].putNotNull(rowId);
columns[ordinal].putInterval(rowId, value);
}
}


@@ -27,6 +27,7 @@ import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -372,6 +373,12 @@ public abstract class WritableColumnVector extends ColumnVector {
}
}
public void putInterval(int rowId, CalendarInterval value) {
getChild(0).putInt(rowId, value.months);
getChild(1).putInt(rowId, value.days);
getChild(2).putLong(rowId, value.microseconds);
}
@Override
public UTF8String getUTF8String(int rowId) {
if (isNullAt(rowId)) return null;


@@ -38,7 +38,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.util.Utils
@@ -643,7 +643,8 @@ case class HashAggregateExec(
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
val isSupported =
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
f.dataType.isInstanceOf[CalendarIntervalType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
// For vectorized hash map, we do not support byte array based decimal type for aggregate values
@@ -655,7 +656,7 @@ case class HashAggregateExec(
val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
.forall(!DecimalType.isByteArrayDecimalType(_))
isSupported && isNotByteArrayDecimalType
isSupported && isNotByteArrayDecimalType
}
private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {


@@ -127,7 +127,8 @@ class RowBasedHashMapGenerator(
case t: DecimalType =>
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
case t: DataType =>
if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) {
if (!t.isInstanceOf[StringType] && !t.isInstanceOf[CalendarIntervalType] &&
!CodeGenerator.isPrimitiveType(t)) {
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
}
s"agg_rowWriter.write(${ordinal}, ${key.name})"


@@ -29,8 +29,8 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.DecimalData
import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, IntegerType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
@@ -951,4 +951,17 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
assert(error.message.contains("function count_if requires boolean type"))
}
}
test("calendar interval agg support hash aggregate") {
val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
val df2 = df1.select(avg('b cast CalendarIntervalType))
checkAnswer(df2, Row(new CalendarInterval(0, 2, 0)) :: Nil)
assert(df2.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
val df3 = df1.groupBy('a).agg(avg('b cast CalendarIntervalType))
checkAnswer(df3,
Row(1, new CalendarInterval(0, 1, 0)) ::
Row(2, new CalendarInterval(0, 2, 0)) ::
Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
assert(df3.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
}
}


@@ -647,29 +647,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(days.dataType() == IntegerType)
assert(microseconds.dataType() == LongType)
months.putInt(0, 1)
days.putInt(0, 10)
microseconds.putLong(0, 100)
reference += new CalendarInterval(1, 10, 100)
months.putInt(1, 0)
days.putInt(1, 0)
microseconds.putLong(1, 2000)
reference += new CalendarInterval(0, 0, 2000)
column.putNull(2)
assert(column.getInterval(2) == null)
reference += null
months.putInt(3, 20)
days.putInt(3, 0)
microseconds.putLong(3, 0)
reference += new CalendarInterval(20, 0, 0)
months.putInt(4, 0)
days.putInt(4, 200)
microseconds.putLong(4, 0)
reference += new CalendarInterval(0, 200, 0)
Seq(new CalendarInterval(1, 10, 100),
new CalendarInterval(0, 0, 2000),
new CalendarInterval(20, 0, 0),
new CalendarInterval(0, 200, 0)).zipWithIndex.foreach { case (v, i) =>
column.putInterval(i, v)
reference += v
}
reference.zipWithIndex.foreach { case (v, i) =>
val errMsg = "VectorType=" + column.getClass.getSimpleName