[SPARK-35139][SQL] Support ANSI intervals as Arrow Column vectors
### What changes were proposed in this pull request? Add support for `YearMonthIntervalType` and `DayTimeIntervalType` to `ArrowColumnVector`. ### Why are the changes needed? https://issues.apache.org/jira/browse/SPARK-35139 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 1. By checking coding style via: $ ./dev/scalastyle $ ./dev/lint-java 2. By running the test suite "ArrowWriterSuite" Closes #32340 from Peng-Lei/SPARK-35139. Authored-by: PengLei <18066542445@189.cn> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
7f51106c0d
commit
eb08b9010a
|
@ -19,12 +19,16 @@ package org.apache.spark.sql.vectorized;
|
|||
|
||||
import org.apache.arrow.vector.*;
|
||||
import org.apache.arrow.vector.complex.*;
|
||||
import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
|
||||
import org.apache.arrow.vector.holders.NullableVarCharHolder;
|
||||
|
||||
import org.apache.spark.sql.util.ArrowUtils;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import org.apache.spark.unsafe.types.UTF8String;
|
||||
|
||||
import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY;
|
||||
import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS;
|
||||
|
||||
/**
|
||||
* A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
|
||||
* supported.
|
||||
|
@ -172,6 +176,10 @@ public final class ArrowColumnVector extends ColumnVector {
|
|||
}
|
||||
} else if (vector instanceof NullVector) {
|
||||
accessor = new NullAccessor((NullVector) vector);
|
||||
} else if (vector instanceof IntervalYearVector) {
|
||||
accessor = new IntervalYearAccessor((IntervalYearVector) vector);
|
||||
} else if (vector instanceof IntervalDayVector) {
|
||||
accessor = new IntervalDayAccessor((IntervalDayVector) vector);
|
||||
} else {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
@ -508,4 +516,37 @@ public final class ArrowColumnVector extends ColumnVector {
|
|||
super(vector);
|
||||
}
|
||||
}
|
||||
|
||||
private static class IntervalYearAccessor extends ArrowVectorAccessor {
|
||||
|
||||
private final IntervalYearVector accessor;
|
||||
|
||||
IntervalYearAccessor(IntervalYearVector vector) {
|
||||
super(vector);
|
||||
this.accessor = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
int getInt(int rowId) {
|
||||
return accessor.get(rowId);
|
||||
}
|
||||
}
|
||||
|
||||
private static class IntervalDayAccessor extends ArrowVectorAccessor {
|
||||
|
||||
private final IntervalDayVector accessor;
|
||||
private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();
|
||||
|
||||
IntervalDayAccessor(IntervalDayVector vector) {
|
||||
super(vector);
|
||||
this.accessor = vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
long getLong(int rowId) {
|
||||
accessor.get(rowId, intervalDayHolder);
|
||||
return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY),
|
||||
intervalDayHolder.milliseconds * MICROS_PER_MILLIS);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
|
|||
|
||||
import org.apache.arrow.memory.RootAllocator
|
||||
import org.apache.arrow.vector.complex.MapVector
|
||||
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
|
||||
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
|
||||
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
|
||||
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
@ -54,6 +54,8 @@ private[sql] object ArrowUtils {
|
|||
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
|
||||
}
|
||||
case NullType => ArrowType.Null.INSTANCE
|
||||
case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
|
||||
case DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME)
|
||||
case _ =>
|
||||
throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
|
||||
}
|
||||
|
@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
|
|||
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
|
||||
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
|
||||
case ArrowType.Null.INSTANCE => NullType
|
||||
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType
|
||||
case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType
|
||||
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
|
|||
roundtrip(BinaryType)
|
||||
roundtrip(DecimalType.SYSTEM_DEFAULT)
|
||||
roundtrip(DateType)
|
||||
roundtrip(YearMonthIntervalType)
|
||||
roundtrip(DayTimeIntervalType)
|
||||
val tsExMsg = intercept[UnsupportedOperationException] {
|
||||
roundtrip(TimestampType)
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS}
|
||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.util.ArrowUtils
|
||||
|
@ -74,6 +75,8 @@ object ArrowWriter {
|
|||
}
|
||||
new StructWriter(vector, children.toArray)
|
||||
case (NullType, vector: NullVector) => new NullWriter(vector)
|
||||
case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
|
||||
case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
|
||||
case (dt, _) =>
|
||||
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
|
||||
}
|
||||
|
@ -394,3 +397,28 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldW
|
|||
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Field writer that stores int values into an Arrow `IntervalYearVector`.
 *
 * @param valueVector the Arrow vector that receives the written values
 */
private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
  extends ArrowFieldWriter {

  /** Marks the slot at the writer's current `count` index as null. */
  override def setNull(): Unit = valueVector.setNull(count)

  /**
   * Reads the int at `ordinal` from `input` and stores it at the writer's current `count`
   * index, using the vector's `setSafe` variant.
   */
  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
    val intervalValue = input.getInt(ordinal)
    valueVector.setSafe(count, intervalValue)
  }
}
|
||||
|
||||
/**
 * Field writer that stores long microsecond values into an Arrow `IntervalDayVector` as a
 * (days, milliseconds) pair.
 *
 * @param valueVector the Arrow vector that receives the written values
 */
private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector)
  extends ArrowFieldWriter {

  /** Marks the slot at the writer's current `count` index as null. */
  override def setNull(): Unit = {
    valueVector.setNull(count)
  }

  /**
   * Reads the long at `ordinal` from `input`, splits it into whole days and the remaining
   * milliseconds, and stores the pair at the writer's current `count` index.
   *
   * The integer division by `MICROS_PER_MILLIS` discards any sub-millisecond remainder, so
   * only millisecond precision is written to the vector.
   */
  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
    val totalMicroseconds = input.getLong(ordinal)
    val days = totalMicroseconds / MICROS_PER_DAY
    val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS
    // Use setSafe (as IntervalYearWriter does) so the vector grows when `count` is past the
    // currently allocated capacity; plain `set` does not reallocate.
    valueVector.setSafe(count, days.toInt, millis.toInt)
  }
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql.execution.arrow
|
||||
|
||||
import org.apache.arrow.vector.IntervalDayVector
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
|
@ -54,6 +56,8 @@ class ArrowWriterSuite extends SparkFunSuite {
|
|||
case BinaryType => reader.getBinary(rowId)
|
||||
case DateType => reader.getInt(rowId)
|
||||
case TimestampType => reader.getLong(rowId)
|
||||
case YearMonthIntervalType => reader.getInt(rowId)
|
||||
case DayTimeIntervalType => reader.getLong(rowId)
|
||||
}
|
||||
assert(value === datum)
|
||||
}
|
||||
|
@ -73,6 +77,33 @@ class ArrowWriterSuite extends SparkFunSuite {
|
|||
check(DateType, Seq(0, 1, 2, null, 4))
|
||||
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
|
||||
check(NullType, Seq(null, null, null))
|
||||
check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))
|
||||
check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L),
|
||||
(Long.MinValue + 808L)))
|
||||
}
|
||||
|
||||
// Verifies that reading a day-time interval whose microsecond total does not fit in a long
// fails with an ArithmeticException instead of silently wrapping around.
test("long overflow for DayTimeIntervalType") {
  val schema = new StructType().add("value", DayTimeIntervalType, nullable = true)
  val writer = ArrowWriter.create(schema, null)
  try {
    val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
    val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector]

    valueVector.set(0, 106751992, 0)
    valueVector.set(1, 106751991, Int.MaxValue)

    // Row 0: exercises the overflow check in Math.multiplyExact().
    val msg = intercept[java.lang.ArithmeticException] {
      reader.getLong(0)
    }.getMessage
    assert(msg === "long overflow")

    // Row 1: exercises the overflow check in Math.addExact().
    val msg1 = intercept[java.lang.ArithmeticException] {
      reader.getLong(1)
    }.getMessage
    assert(msg1 === "long overflow")
  } finally {
    // Release the Arrow root even when an intercept or assertion above fails.
    writer.root.close()
  }
}
|
||||
|
||||
test("get multiple") {
|
||||
|
@ -97,6 +128,8 @@ class ArrowWriterSuite extends SparkFunSuite {
|
|||
case DoubleType => reader.getDoubles(0, data.size)
|
||||
case DateType => reader.getInts(0, data.size)
|
||||
case TimestampType => reader.getLongs(0, data.size)
|
||||
case YearMonthIntervalType => reader.getInts(0, data.size)
|
||||
case DayTimeIntervalType => reader.getLongs(0, data.size)
|
||||
}
|
||||
assert(values === data)
|
||||
|
||||
|
@ -111,6 +144,8 @@ class ArrowWriterSuite extends SparkFunSuite {
|
|||
check(DoubleType, (0 until 10).map(_.toDouble))
|
||||
check(DateType, (0 until 10))
|
||||
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
|
||||
check(YearMonthIntervalType, (0 until 10))
|
||||
check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong))
|
||||
}
|
||||
|
||||
test("array") {
|
||||
|
|
Loading…
Reference in a new issue