[SPARK-30047][SQL] Support interval types in UnsafeRow
### What changes were proposed in this pull request? Optimize aggregates on interval values from sort-based to hash-based, and we can use the `org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch` for better performance. ### Why are the changes needed? improve aggerates ### Does this PR introduce any user-facing change? no ### How was this patch tested? add ut and existing ones Closes #26680 from yaooqinn/SPARK-30047. Authored-by: Kent Yao <yaooqinn@hotmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
04a5b8f5f8
commit
4e073f3c50
|
@ -103,7 +103,8 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
|
|||
}
|
||||
|
||||
public static boolean isMutable(DataType dt) {
|
||||
return mutableFieldTypes.contains(dt) || dt instanceof DecimalType;
|
||||
return mutableFieldTypes.contains(dt) || dt instanceof DecimalType ||
|
||||
dt instanceof CalendarIntervalType;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -297,6 +298,26 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInterval(int ordinal, CalendarInterval value) {
|
||||
assertIndexIsValid(ordinal);
|
||||
long cursor = getLong(ordinal) >>> 32;
|
||||
assert cursor > 0 : "invalid cursor " + cursor;
|
||||
if (value == null) {
|
||||
setNullAt(ordinal);
|
||||
// zero-out the bytes
|
||||
Platform.putLong(baseObject, baseOffset + cursor, 0L);
|
||||
Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
|
||||
// keep the offset for future update
|
||||
Platform.putLong(baseObject, getFieldOffset(ordinal), (cursor << 32) | 16L);
|
||||
} else {
|
||||
Platform.putInt(baseObject, baseOffset + cursor, value.months);
|
||||
Platform.putInt(baseObject, baseOffset + cursor + 4, value.days);
|
||||
Platform.putLong(baseObject, baseOffset + cursor + 8, value.microseconds);
|
||||
setLong(ordinal, (cursor << 32) | 16L);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object get(int ordinal, DataType dataType) {
|
||||
return SpecializedGettersReader.read(this, ordinal, dataType, true, true);
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
|
|||
import org.apache.spark.sql.types.Decimal;
|
||||
import org.apache.spark.unsafe.Platform;
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods;
|
||||
import org.apache.spark.unsafe.bitset.BitSetMethods;
|
||||
import org.apache.spark.unsafe.types.CalendarInterval;
|
||||
import org.apache.spark.unsafe.types.UTF8String;
|
||||
|
||||
|
@ -134,13 +135,16 @@ public abstract class UnsafeWriter {
|
|||
// grow the global buffer before writing data.
|
||||
grow(16);
|
||||
|
||||
// Write the months, days and microseconds fields of Interval to the variable length portion.
|
||||
Platform.putInt(getBuffer(), cursor(), input.months);
|
||||
Platform.putInt(getBuffer(), cursor() + 4, input.days);
|
||||
Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
|
||||
|
||||
if (input == null) {
|
||||
BitSetMethods.set(getBuffer(), startingOffset, ordinal);
|
||||
} else {
|
||||
// Write the months, days and microseconds fields of interval to the variable length portion.
|
||||
Platform.putInt(getBuffer(), cursor(), input.months);
|
||||
Platform.putInt(getBuffer(), cursor() + 4, input.days);
|
||||
Platform.putLong(getBuffer(), cursor() + 8, input.microseconds);
|
||||
}
|
||||
// we need to reserve the space so that we can update it later.
|
||||
setOffsetAndSize(ordinal, 16);
|
||||
|
||||
// move the cursor forward.
|
||||
increaseCursor(16);
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
/**
|
||||
* An abstract class for row used internally in Spark SQL, which only contains the columns as
|
||||
|
@ -58,6 +58,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
|
|||
*/
|
||||
def setDecimal(i: Int, value: Decimal, precision: Int): Unit = update(i, value)
|
||||
|
||||
def setInterval(i: Int, value: CalendarInterval): Unit = update(i, value)
|
||||
|
||||
/**
|
||||
* Make a copy of the current [[InternalRow]] object.
|
||||
*/
|
||||
|
|
|
@ -605,6 +605,7 @@ class CodegenContext extends Logging {
|
|||
s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
|
||||
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
|
||||
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
|
||||
case CalendarIntervalType => s"$c1.equals($c2)"
|
||||
case array: ArrayType => genComp(array, c1, c2) + " == 0"
|
||||
case struct: StructType => genComp(struct, c1, c2) + " == 0"
|
||||
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
|
||||
|
@ -1579,6 +1580,7 @@ object CodeGenerator extends Logging {
|
|||
val jt = javaType(dataType)
|
||||
dataType match {
|
||||
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
|
||||
case CalendarIntervalType => s"$row.setInterval($ordinal, $value)"
|
||||
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
|
||||
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
|
||||
// The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
|
||||
|
@ -1602,8 +1604,10 @@ object CodeGenerator extends Logging {
|
|||
nullable: Boolean,
|
||||
isVectorized: Boolean = false): String = {
|
||||
if (nullable) {
|
||||
// Can't call setNullAt on DecimalType, because we need to keep the offset
|
||||
if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
|
||||
// Can't call setNullAt on DecimalType/CalendarIntervalType, because we need to keep the
|
||||
// offset
|
||||
if (!isVectorized && (dataType.isInstanceOf[DecimalType] ||
|
||||
dataType.isInstanceOf[CalendarIntervalType])) {
|
||||
s"""
|
||||
|if (!${ev.isNull}) {
|
||||
| ${setColumn(row, dataType, ordinal, ev.value)};
|
||||
|
@ -1634,6 +1638,7 @@ object CodeGenerator extends Logging {
|
|||
case _ if isPrimitiveType(jt) =>
|
||||
s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
|
||||
case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
|
||||
case CalendarIntervalType => s"$vector.putInterval($rowId, $value);"
|
||||
case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
|
||||
case _ =>
|
||||
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
|
||||
|
|
|
@ -111,6 +111,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
|
||||
// Can't call setNullAt() for DecimalType with precision larger than 18.
|
||||
s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
|
||||
case CalendarIntervalType => s"$rowWriter.write($index, (CalendarInterval) null);"
|
||||
case _ => s"$rowWriter.setNullAt($index);"
|
||||
}
|
||||
|
||||
|
|
|
@ -124,6 +124,36 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB
|
|||
(Timestamp.valueOf("2015-06-22 08:10:25"))
|
||||
}
|
||||
|
||||
testBothCodegenAndInterpreted(
|
||||
"basic conversion with primitive, string and interval types") {
|
||||
val factory = UnsafeProjection
|
||||
val fieldTypes: Array[DataType] = Array(LongType, StringType, CalendarIntervalType)
|
||||
val converter = factory.create(fieldTypes)
|
||||
|
||||
val row = new SpecificInternalRow(fieldTypes)
|
||||
row.setLong(0, 0)
|
||||
row.update(1, UTF8String.fromString("Hello"))
|
||||
val interval1 = new CalendarInterval(3, 1, 1000L)
|
||||
row.update(2, interval1)
|
||||
|
||||
val unsafeRow: UnsafeRow = converter.apply(row)
|
||||
assert(unsafeRow.getSizeInBytes ===
|
||||
8 + 8 * 3 + roundedSize("Hello".getBytes(StandardCharsets.UTF_8).length) + 16)
|
||||
|
||||
assert(unsafeRow.getLong(0) === 0)
|
||||
assert(unsafeRow.getString(1) === "Hello")
|
||||
assert(unsafeRow.getInterval(2) === interval1)
|
||||
|
||||
val interval2 = new CalendarInterval(1, 2, 3L)
|
||||
unsafeRow.setInterval(2, interval2)
|
||||
assert(unsafeRow.getInterval(2) === interval2)
|
||||
|
||||
val offset = unsafeRow.getLong(2) >>> 32
|
||||
unsafeRow.setInterval(2, null)
|
||||
assert(unsafeRow.getInterval(2) === null)
|
||||
assert(unsafeRow.getLong(2) >>> 32 === offset)
|
||||
}
|
||||
|
||||
testBothCodegenAndInterpreted("null handling") {
|
||||
val factory = UnsafeProjection
|
||||
val fieldTypes: Array[DataType] = Array(
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
class UnsafeRowWriterSuite extends SparkFunSuite {
|
||||
|
||||
|
@ -49,4 +50,15 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
|
|||
// The two rows should be the equal
|
||||
assert(res1 == res2)
|
||||
}
|
||||
|
||||
test("write and get calendar intervals through UnsafeRowWriter") {
|
||||
val rowWriter = new UnsafeRowWriter(2)
|
||||
rowWriter.resetRowWriter()
|
||||
rowWriter.write(0, null.asInstanceOf[CalendarInterval])
|
||||
assert(rowWriter.getRow.isNullAt(0))
|
||||
assert(rowWriter.getRow.getInterval(0) === null)
|
||||
val interval = new CalendarInterval(0, 1, 0)
|
||||
rowWriter.write(1, interval)
|
||||
assert(rowWriter.getRow.getInterval(1) === interval)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -212,6 +212,8 @@ public final class MutableColumnarRow extends InternalRow {
|
|||
DecimalType t = (DecimalType) dt;
|
||||
Decimal d = Decimal.apply((BigDecimal) value, t.precision(), t.scale());
|
||||
setDecimal(ordinal, d, t.precision());
|
||||
} else if (dt instanceof CalendarIntervalType) {
|
||||
setInterval(ordinal, (CalendarInterval) value);
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Datatype not supported " + dt);
|
||||
}
|
||||
|
@ -270,4 +272,10 @@ public final class MutableColumnarRow extends InternalRow {
|
|||
columns[ordinal].putNotNull(rowId);
|
||||
columns[ordinal].putDecimal(rowId, value, precision);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInterval(int ordinal, CalendarInterval value) {
|
||||
columns[ordinal].putNotNull(rowId);
|
||||
columns[ordinal].putInterval(rowId, value);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.sql.vectorized.ColumnVector;
|
|||
import org.apache.spark.sql.vectorized.ColumnarArray;
|
||||
import org.apache.spark.sql.vectorized.ColumnarMap;
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods;
|
||||
import org.apache.spark.unsafe.types.CalendarInterval;
|
||||
import org.apache.spark.unsafe.types.UTF8String;
|
||||
|
||||
/**
|
||||
|
@ -372,6 +373,12 @@ public abstract class WritableColumnVector extends ColumnVector {
|
|||
}
|
||||
}
|
||||
|
||||
public void putInterval(int rowId, CalendarInterval value) {
|
||||
getChild(0).putInt(rowId, value.months);
|
||||
getChild(1).putInt(rowId, value.days);
|
||||
getChild(2).putLong(rowId, value.microseconds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UTF8String getUTF8String(int rowId) {
|
||||
if (isNullAt(rowId)) return null;
|
||||
|
|
|
@ -38,7 +38,7 @@ import org.apache.spark.sql.execution._
|
|||
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
|
||||
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
|
||||
import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType}
|
||||
import org.apache.spark.unsafe.KVIterator
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
@ -643,7 +643,8 @@ case class HashAggregateExec(
|
|||
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
|
||||
val isSupported =
|
||||
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
|
||||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
|
||||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType] ||
|
||||
f.dataType.isInstanceOf[CalendarIntervalType]) &&
|
||||
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
|
||||
|
||||
// For vectorized hash map, We do not support byte array based decimal type for aggregate values
|
||||
|
@ -655,7 +656,7 @@ case class HashAggregateExec(
|
|||
val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
|
||||
.forall(!DecimalType.isByteArrayDecimalType(_))
|
||||
|
||||
isSupported && isNotByteArrayDecimalType
|
||||
isSupported && isNotByteArrayDecimalType
|
||||
}
|
||||
|
||||
private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
|
||||
|
|
|
@ -127,7 +127,8 @@ class RowBasedHashMapGenerator(
|
|||
case t: DecimalType =>
|
||||
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
|
||||
case t: DataType =>
|
||||
if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) {
|
||||
if (!t.isInstanceOf[StringType] && !t.isInstanceOf[CalendarIntervalType] &&
|
||||
!CodeGenerator.isPrimitiveType(t)) {
|
||||
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
|
||||
}
|
||||
s"agg_rowWriter.write(${ordinal}, ${key.name})"
|
||||
|
|
|
@ -29,8 +29,8 @@ import org.apache.spark.sql.functions._
|
|||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.test.SQLTestData.DecimalData
|
||||
import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, IntegerType}
|
||||
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
|
||||
|
||||
|
@ -951,4 +951,17 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
|
|||
assert(error.message.contains("function count_if requires boolean type"))
|
||||
}
|
||||
}
|
||||
|
||||
test("calendar interval agg support hash aggregate") {
|
||||
val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
|
||||
val df2 = df1.select(avg('b cast CalendarIntervalType))
|
||||
checkAnswer(df2, Row(new CalendarInterval(0, 2, 0)) :: Nil)
|
||||
assert(df2.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
val df3 = df1.groupBy('a).agg(avg('b cast CalendarIntervalType))
|
||||
checkAnswer(df3,
|
||||
Row(1, new CalendarInterval(0, 1, 0)) ::
|
||||
Row(2, new CalendarInterval(0, 2, 0)) ::
|
||||
Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
|
||||
assert(df3.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -647,29 +647,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
|
|||
assert(days.dataType() == IntegerType)
|
||||
assert(microseconds.dataType() == LongType)
|
||||
|
||||
months.putInt(0, 1)
|
||||
days.putInt(0, 10)
|
||||
microseconds.putLong(0, 100)
|
||||
reference += new CalendarInterval(1, 10, 100)
|
||||
|
||||
months.putInt(1, 0)
|
||||
days.putInt(1, 0)
|
||||
microseconds.putLong(1, 2000)
|
||||
reference += new CalendarInterval(0, 0, 2000)
|
||||
|
||||
column.putNull(2)
|
||||
assert(column.getInterval(2) == null)
|
||||
reference += null
|
||||
|
||||
months.putInt(3, 20)
|
||||
days.putInt(3, 0)
|
||||
microseconds.putLong(3, 0)
|
||||
reference += new CalendarInterval(20, 0, 0)
|
||||
|
||||
months.putInt(4, 0)
|
||||
days.putInt(4, 200)
|
||||
microseconds.putLong(4, 0)
|
||||
reference += new CalendarInterval(0, 200, 0)
|
||||
Seq(new CalendarInterval(1, 10, 100),
|
||||
new CalendarInterval(0, 0, 2000),
|
||||
new CalendarInterval(20, 0, 0),
|
||||
new CalendarInterval(0, 200, 0)).zipWithIndex.foreach { case (v, i) =>
|
||||
column.putInterval(i, v)
|
||||
reference += v
|
||||
}
|
||||
|
||||
reference.zipWithIndex.foreach { case (v, i) =>
|
||||
val errMsg = "VectorType=" + column.getClass.getSimpleName
|
||||
|
|
Loading…
Reference in a new issue