[SPARK-30066][SQL] Support columnar execution on interval types
### What changes were proposed in this pull request? Add columnar execution support for interval types. ### Why are the changes needed? To support caching tables with interval columns, and to improve performance as well. ### Does this PR introduce any user-facing change? Yes, CACHE TABLE will accept interval columns. ### How was this patch tested? Added unit tests. Closes #26699 from yaooqinn/SPARK-30066. Authored-by: Kent Yao <yaooqinn@hotmail.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
f197204f03
commit
d3ec8b1735
|
@ -133,4 +133,16 @@ public final class CalendarInterval implements Serializable, Comparable<Calendar
|
|||
* @throws ArithmeticException if a numeric overflow occurs
|
||||
*/
|
||||
public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }
|
||||
|
||||
/**
|
||||
* A constant holding the minimum value an {@code CalendarInterval} can have.
|
||||
*/
|
||||
public static CalendarInterval MIN_VALUE =
|
||||
new CalendarInterval(Integer.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE);
|
||||
|
||||
/**
|
||||
* A constant holding the maximum value an {@code CalendarInterval} can have.
|
||||
*/
|
||||
public static CalendarInterval MAX_VALUE =
|
||||
new CalendarInterval(Integer.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE);
|
||||
}
|
||||
|
|
|
@ -171,6 +171,8 @@ object InternalRow {
|
|||
case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
|
||||
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
|
||||
case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
|
||||
case CalendarIntervalType =>
|
||||
(input, v) => input.setInterval(ordinal, v.asInstanceOf[CalendarInterval])
|
||||
case DecimalType.Fixed(precision, _) =>
|
||||
(input, v) => input.setDecimal(ordinal, v.asInstanceOf[Decimal], precision)
|
||||
case udt: UserDefinedType[_] => getWriter(ordinal, udt.sqlType)
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData
|
|||
import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
|
||||
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
/**
|
||||
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
|
||||
|
@ -104,6 +105,10 @@ private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
|
|||
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
|
||||
with NullableColumnAccessor
|
||||
|
||||
/**
 * Column accessor for [[CalendarInterval]] values: a [[BasicColumnAccessor]] bound to the
 * [[CALENDAR_INTERVAL]] column type, mixed with [[NullableColumnAccessor]].
 *
 * NOTE(review): `dataType` is not referenced in this declaration; presumably it is kept so the
 * constructor shape matches what its callers pass -- confirm before removing.
 */
private[columnar] class IntervalColumnAccessor(
    buffer: ByteBuffer,
    dataType: CalendarIntervalType)
  extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
  with NullableColumnAccessor
|
||||
|
||||
private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
|
||||
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
|
||||
|
||||
|
|
|
@ -125,6 +125,9 @@ class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STR
|
|||
private[columnar]
|
||||
class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
|
||||
|
||||
/**
 * Column builder for interval columns: a [[ComplexColumnBuilder]] constructed with a fresh
 * [[IntervalColumnStats]] and the [[CALENDAR_INTERVAL]] column type.
 */
private[columnar] class IntervalColumnBuilder
  extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)
|
||||
|
||||
private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
|
||||
extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
|
||||
|
||||
|
@ -176,6 +179,7 @@ private[columnar] object ColumnBuilder {
|
|||
case DoubleType => new DoubleColumnBuilder
|
||||
case StringType => new StringColumnBuilder
|
||||
case BinaryType => new BinaryColumnBuilder
|
||||
case CalendarIntervalType => new IntervalColumnBuilder
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
|
||||
new CompactDecimalColumnBuilder(dt)
|
||||
case dt: DecimalType => new DecimalColumnBuilder(dt)
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
|
||||
val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
|
||||
|
@ -295,6 +295,26 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
|
|||
Array[Any](null, null, nullCount, count, sizeInBytes)
|
||||
}
|
||||
|
||||
/**
 * Column statistics for [[CalendarInterval]] columns.
 *
 * Tracks the largest and smallest non-null values seen, as ordered by
 * `CalendarInterval.compareTo`, and reports them together with `nullCount`, `count` and
 * `sizeInBytes`.
 */
private[columnar] final class IntervalColumnStats extends ColumnStats {
  // The bounds start out inverted (upper = MIN_VALUE, lower = MAX_VALUE); they are only
  // replaced by a value that compares strictly greater / strictly smaller.
  protected var upper: CalendarInterval = CalendarInterval.MIN_VALUE
  protected var lower: CalendarInterval = CalendarInterval.MAX_VALUE

  /**
   * Folds the value at `ordinal` of `row` into the statistics.
   *
   * @param row     row holding the interval column
   * @param ordinal position of the interval column within `row`
   */
  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
    if (row.isNullAt(ordinal)) {
      gatherNullStats
    } else {
      val interval = row.getInterval(ordinal)
      if (interval.compareTo(upper) > 0) upper = interval
      if (interval.compareTo(lower) < 0) lower = interval
      sizeInBytes += CALENDAR_INTERVAL.actualSize(row, ordinal)
      count += 1
    }
  }

  /** Returns the statistics in the order: lower, upper, nullCount, count, sizeInBytes. */
  override def collectedStatistics: Array[Any] = {
    Array[Any](lower, upper, nullCount, count, sizeInBytes)
  }
}
|
||||
|
||||
private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
|
||||
def this(dt: DecimalType) = this(dt.precision, dt.scale)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
|
||||
/**
|
||||
|
@ -705,6 +705,37 @@ private[columnar] case class MAP(dataType: MapType)
|
|||
override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
|
||||
}
|
||||
|
||||
/**
 * Column type for [[CalendarInterval]] values.
 *
 * `append` writes four fields in order: an Int holding the literal 16, then `months` (Int),
 * `days` (Int) and `microseconds` (Long). `extract` reads the same four fields back and
 * ignores the first one.
 */
private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
  with DirectCopyColumnType[CalendarInterval] {

  override def dataType: DataType = CalendarIntervalType

  override def defaultSize: Int = 16

  // NOTE(review): 20 presumably corresponds to the 4-byte leading Int plus the 16 bytes of
  // months/days/microseconds written by `append` -- confirm against ByteBufferHelper.
  override def actualSize(row: InternalRow, ordinal: Int): Int = 20

  override def getField(row: InternalRow, ordinal: Int): CalendarInterval = {
    row.getInterval(ordinal)
  }

  override def setField(row: InternalRow, ordinal: Int, value: CalendarInterval): Unit =
    row.setInterval(ordinal, value)

  /** Reads one interval from `buffer`, in the field order written by `append`. */
  override def extract(buffer: ByteBuffer): CalendarInterval = {
    // The leading Int is read and its value discarded.
    ByteBufferHelper.getInt(buffer)
    val monthsPart = ByteBufferHelper.getInt(buffer)
    val daysPart = ByteBufferHelper.getInt(buffer)
    val microsPart = ByteBufferHelper.getLong(buffer)
    new CalendarInterval(monthsPart, daysPart, microsPart)
  }

  /** Writes `v` to `buffer`, preceded by an Int holding 16. */
  override def append(v: CalendarInterval, buffer: ByteBuffer): Unit = {
    ByteBufferHelper.putInt(buffer, 16)
    ByteBufferHelper.putInt(buffer, v.months)
    ByteBufferHelper.putInt(buffer, v.days)
    ByteBufferHelper.putLong(buffer, v.microseconds)
  }
}
|
||||
|
||||
private[columnar] object ColumnType {
|
||||
@tailrec
|
||||
def apply(dataType: DataType): ColumnType[_] = {
|
||||
|
@ -719,6 +750,7 @@ private[columnar] object ColumnType {
|
|||
case DoubleType => DOUBLE
|
||||
case StringType => STRING
|
||||
case BinaryType => BINARY
|
||||
case i: CalendarIntervalType => CALENDAR_INTERVAL
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt)
|
||||
case dt: DecimalType => LARGE_DECIMAL(dt)
|
||||
case arr: ArrayType => ARRAY(arr)
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
/**
|
||||
* An Iterator to walk through the InternalRows from a CachedBatch
|
||||
|
@ -51,6 +52,10 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR
|
|||
// The writer will be used directly to avoid creating wrapper objects, so this setter is
// not supported and always throws.
override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = {
  throw new UnsupportedOperationException
}
|
||||
|
||||
// Not supported on this row: always throws UnsupportedOperationException.
override def setInterval(i: Int, value: CalendarInterval): Unit = {
  throw new UnsupportedOperationException
}
|
||||
|
||||
// Generic updates are not supported on this row: always throws.
override def update(i: Int, v: Any): Unit = {
  throw new UnsupportedOperationException
}
|
||||
|
||||
// all other methods inherited from GenericMutableRow are not needed
|
||||
|
@ -81,6 +86,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
|
|||
case DoubleType => classOf[DoubleColumnAccessor].getName
|
||||
case StringType => classOf[StringColumnAccessor].getName
|
||||
case BinaryType => classOf[BinaryColumnAccessor].getName
|
||||
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
|
||||
classOf[CompactDecimalColumnAccessor].getName
|
||||
case dt: DecimalType => classOf[DecimalColumnAccessor].getName
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
|
|||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants
|
||||
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.columnar._
|
||||
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
|
||||
|
@ -36,6 +37,7 @@ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
|
|||
import org.apache.spark.sql.types.{StringType, StructField, StructType}
|
||||
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
|
||||
import org.apache.spark.storage.StorageLevel.{MEMORY_AND_DISK_2, MEMORY_ONLY}
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
import org.apache.spark.util.{AccumulatorContext, Utils}
|
||||
|
||||
private case class BigData(s: String)
|
||||
|
@ -1094,4 +1096,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSparkSessi
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that a query result containing an interval column can be cached, queried and
// uncached.
test("cache supports for intervals") {
  withTable("interval_cache") {
    // `t1` is created by the CACHE TABLE ... AS SELECT statement below. Guarding it with
    // withTempView cleans it up after the test, including when an assertion fails before
    // the explicit UNCACHE, so it cannot leak into other tests.
    withTempView("t1") {
      Seq((1, "1 second"), (2, "2 seconds"), (2, null))
        .toDF("k", "v").write.saveAsTable("interval_cache")
      sql("CACHE TABLE t1 AS SELECT k, cast(v as interval) FROM interval_cache")
      assert(spark.catalog.isCached("t1"))
      checkAnswer(sql("SELECT * FROM t1 WHERE k = 1"),
        Row(1, new CalendarInterval(0, 0, DateTimeConstants.MICROS_PER_SECOND)))
      sql("UNCACHE TABLE t1")
      assert(!spark.catalog.isCached("t1"))
    }
  }
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
class ColumnStatsSuite extends SparkFunSuite {
|
||||
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
|
||||
|
@ -30,6 +31,7 @@ class ColumnStatsSuite extends SparkFunSuite {
|
|||
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
|
||||
testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
|
||||
testDecimalColumnStats(Array(null, null, 0))
|
||||
testIntervalColumnStats(Array(CalendarInterval.MAX_VALUE, CalendarInterval.MIN_VALUE, 0))
|
||||
|
||||
def testColumnStats[T <: AtomicType, U <: ColumnStats](
|
||||
columnStatsClass: Class[U],
|
||||
|
@ -103,4 +105,40 @@ class ColumnStatsSuite extends SparkFunSuite {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Registers the "empty" and "non-empty" test cases for [[IntervalColumnStats]].
 *
 * The type parameters are unused; they are kept so the signature stays unchanged.
 *
 * @param initialStatistics expected leading entries of `collectedStatistics` for a freshly
 *                          created stats object. They are compared pairwise via `zip`, so
 *                          only as many entries as are supplied here get checked.
 */
def testIntervalColumnStats[T <: AtomicType, U <: ColumnStats](
    initialStatistics: Array[Any]): Unit = {

  val columnStatsName = classOf[IntervalColumnStats].getSimpleName
  val columnType = CALENDAR_INTERVAL

  test(s"$columnStatsName: empty") {
    val columnStats = new IntervalColumnStats
    columnStats.collectedStatistics.zip(initialStatistics).foreach {
      case (actual, expected) => assert(actual === expected)
    }
  }

  test(s"$columnStatsName: non-empty") {
    import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._

    val columnStats = new IntervalColumnStats
    // 10 random non-null rows followed by 10 null rows.
    val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
    rows.foreach(columnStats.gatherStats(_, 0))

    val values = rows.take(10).map(_.get(0, columnType.dataType))
    val ordering = CalendarIntervalType.ordering.asInstanceOf[Ordering[Any]]
    val stats = columnStats.collectedStatistics

    // Each null row is counted as 4 bytes here; non-null rows use the column type's size.
    val expectedSizeInBytes = rows.map { row =>
      if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
    }.sum

    assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
    assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
    assertResult(10, "Wrong null count")(stats(2))
    assertResult(20, "Wrong row count")(stats(3))
    // Computed value goes in the `expected` slot and the observed statistic in `actual`,
    // so a failure message reports them the right way round.
    assertResult(expectedSizeInBytes, "Wrong size in bytes")(stats(4))
  }
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
|||
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
|
||||
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
class ColumnTypeSuite extends SparkFunSuite with Logging {
|
||||
private val DEFAULT_BUFFER_SIZE = 512
|
||||
|
@ -38,7 +39,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
val checks = Map(
|
||||
NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
|
||||
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 10) -> 12,
|
||||
STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68)
|
||||
STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
|
||||
CALENDAR_INTERVAL -> 16)
|
||||
|
||||
checks.foreach { case (columnType, expectedSize) =>
|
||||
assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
|
||||
|
@ -76,6 +78,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
checkActualSize(ARRAY_TYPE, Array[Any](1), 4 + 8 + 8 + 8)
|
||||
checkActualSize(MAP_TYPE, Map(1 -> "a"), 4 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8))
|
||||
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
|
||||
checkActualSize(CALENDAR_INTERVAL, CalendarInterval.MAX_VALUE, 4 + 4 + 4 + 8)
|
||||
}
|
||||
|
||||
testNativeColumnType(BOOLEAN)
|
||||
|
@ -94,6 +97,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
|
|||
testColumnType(STRUCT_TYPE)
|
||||
testColumnType(ARRAY_TYPE)
|
||||
testColumnType(MAP_TYPE)
|
||||
testColumnType(CALENDAR_INTERVAL)
|
||||
|
||||
def testNativeColumnType[T <: AtomicType](columnType: NativeColumnType[T]): Unit = {
|
||||
testColumnType[T#InternalType](columnType)
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
|
||||
import org.apache.spark.sql.types.{AtomicType, Decimal}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
object ColumnarTestUtils {
|
||||
def makeNullRow(length: Int): GenericInternalRow = {
|
||||
|
@ -51,6 +51,8 @@ object ColumnarTestUtils {
|
|||
case DOUBLE => Random.nextDouble()
|
||||
case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32)))
|
||||
case BINARY => randomBytes(Random.nextInt(32))
|
||||
case CALENDAR_INTERVAL =>
|
||||
new CalendarInterval(Random.nextInt(), Random.nextInt(), Random.nextLong())
|
||||
case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
|
||||
case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
|
||||
case STRUCT(_) =>
|
||||
|
|
|
@ -44,7 +44,8 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
|
|||
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
|
||||
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
|
||||
STRUCT(StructType(StructField("a", StringType) :: Nil)),
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)),
|
||||
CALENDAR_INTERVAL)
|
||||
.foreach {
|
||||
testNullableColumnAccessor(_)
|
||||
}
|
||||
|
|
|
@ -42,7 +42,8 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
|
|||
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
|
||||
STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
|
||||
STRUCT(StructType(StructField("a", StringType) :: Nil)),
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)))
|
||||
ARRAY(ArrayType(IntegerType)), MAP(MapType(IntegerType, StringType)),
|
||||
CALENDAR_INTERVAL)
|
||||
.foreach {
|
||||
testNullableColumnBuilder(_)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue