[SPARK-34879][SQL] HiveInspector supports DayTimeIntervalType and YearMonthIntervalType

### What changes were proposed in this pull request?
Make `HiveInspectors` support `DayTimeIntervalType` and `YearMonthIntervalType`, so that these two types can be used in Hive UDFs and `HiveScriptTransformation`.

### Why are the changes needed?
Support more data types when using a Hive serde.

### Does this PR introduce _any_ user-facing change?
Users can use `DayTimeIntervalType` and `YearMonthIntervalType` in Hive UDFs and `HiveScriptTransformation`.
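For example, a hedged usage sketch (not taken from this PR; the ANSI interval type names in the `AS` clause and exact parser support for them are assumptions):

```scala
// Hypothetical sketch: interval columns round-tripping through a Hive-serde
// script transformation. The SQL type names below are assumed.
import java.time.{Duration, Period}

val df = Seq((Duration.ofDays(1), Period.ofMonths(10))).toDF("dt", "ym")
df.createOrReplaceTempView("v")
spark.sql(
  """SELECT TRANSFORM(dt, ym) USING 'cat'
    |AS (dt interval day to second, ym interval year to month)
    |FROM v""".stripMargin).show()
```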

### How was this patch tested?
Added UT

Closes #31979 from AngersZhuuuu/SPARK-34879.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
2 changed files with 154 additions and 1 deletion.

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

@@ -18,11 +18,12 @@
package org.apache.spark.sql.hive
import java.lang.reflect.{ParameterizedType, Type, WildcardType}
import java.time.Duration
import scala.collection.JavaConverters._
import org.apache.hadoop.{io => hadoopIo}
import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveIntervalDayTime, HiveIntervalYearMonth, HiveVarchar}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _}
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
@@ -346,6 +347,17 @@ private[hive] trait HiveInspectors {
withNullSafe(o => getTimestampWritable(o))
case _: TimestampObjectInspector =>
withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
case _: HiveIntervalDayTimeObjectInspector if x.preferWritable() =>
withNullSafe(o => getHiveIntervalDayTimeWritable(o))
case _: HiveIntervalDayTimeObjectInspector =>
withNullSafe(o => {
val duration = IntervalUtils.microsToDuration(o.asInstanceOf[Long])
new HiveIntervalDayTime(duration.getSeconds, duration.getNano)
})
case _: HiveIntervalYearMonthObjectInspector if x.preferWritable() =>
withNullSafe(o => getHiveIntervalYearMonthWritable(o))
case _: HiveIntervalYearMonthObjectInspector =>
withNullSafe(o => new HiveIntervalYearMonth(o.asInstanceOf[Int]))
case _: VoidObjectInspector =>
(_: Any) => null // always be null for void object inspector
}
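The wrap path above splits Catalyst's microsecond `Long` into the `(seconds, nanos)` pair Hive expects. A minimal standalone sketch, using the same helpers the patch imports:

```scala
import java.time.Duration
import org.apache.hadoop.hive.common.`type`.HiveIntervalDayTime
import org.apache.spark.sql.catalyst.util.IntervalUtils

val micros = 100L * 1000000L + 123L                       // 100 s + 123 µs
val duration: Duration = IntervalUtils.microsToDuration(micros)
val hiveValue = new HiveIntervalDayTime(duration.getSeconds, duration.getNano)
// hiveValue.getTotalSeconds == 100L && hiveValue.getNanos == 123000
```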
@@ -512,6 +524,13 @@
_ => constant
case poi: VoidObjectInspector =>
_ => null // always be null for void object inspector
case dt: WritableConstantHiveIntervalDayTimeObjectInspector =>
val constant = dt.getWritableConstantValue.asInstanceOf[HiveIntervalDayTime]
_ => IntervalUtils.durationToMicros(
Duration.ofSeconds(constant.getTotalSeconds).plusNanos(constant.getNanos.toLong))
case ym: WritableConstantHiveIntervalYearMonthObjectInspector =>
val constant = ym.getWritableConstantValue.asInstanceOf[HiveIntervalYearMonth]
_ => constant.getTotalMonths
case pi: PrimitiveObjectInspector => pi match {
// We think HiveVarchar/HiveChar is also a String
case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
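The constant-inspector cases go the other way: Hive's `(totalSeconds, nanos)` pair is folded back into Catalyst's single microseconds `Long`. Sketched in isolation, with the same helpers as the patch:

```scala
import java.time.Duration
import org.apache.hadoop.hive.common.`type`.HiveIntervalDayTime
import org.apache.spark.sql.catalyst.util.IntervalUtils

val constant = new HiveIntervalDayTime(100L, 123000)
val micros = IntervalUtils.durationToMicros(
  Duration.ofSeconds(constant.getTotalSeconds).plusNanos(constant.getNanos.toLong))
// micros == 100000123L
```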
@@ -647,6 +666,42 @@
null
}
}
case dt: HiveIntervalDayTimeObjectInspector if dt.preferWritable() =>
data: Any => {
if (data != null) {
val dayTime = dt.getPrimitiveWritableObject(data).getHiveIntervalDayTime
IntervalUtils.durationToMicros(
Duration.ofSeconds(dayTime.getTotalSeconds).plusNanos(dayTime.getNanos.toLong))
} else {
null
}
}
case dt: HiveIntervalDayTimeObjectInspector =>
data: Any => {
if (data != null) {
val dayTime = dt.getPrimitiveJavaObject(data)
IntervalUtils.durationToMicros(
Duration.ofSeconds(dayTime.getTotalSeconds).plusNanos(dayTime.getNanos.toLong))
} else {
null
}
}
case ym: HiveIntervalYearMonthObjectInspector if ym.preferWritable() =>
data: Any => {
if (data != null) {
ym.getPrimitiveWritableObject(data).getHiveIntervalYearMonth.getTotalMonths
} else {
null
}
}
case ym: HiveIntervalYearMonthObjectInspector =>
data: Any => {
if (data != null) {
ym.getPrimitiveJavaObject(data).getTotalMonths
} else {
null
}
}
case _ =>
data: Any => {
if (data != null) {
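The year-month side is a direct one-to-one mapping, since `YearMonthIntervalType` is stored as an `Int` count of months. A small sketch:

```scala
import org.apache.hadoop.hive.common.`type`.HiveIntervalYearMonth

val wrapped = new HiveIntervalYearMonth(10)  // Catalyst Int -> Hive object
val unwrapped = wrapped.getTotalMonths       // Hive object -> Catalyst Int
// unwrapped == 10
```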
@@ -785,6 +840,10 @@
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
case DayTimeIntervalType =>
PrimitiveObjectInspectorFactory.javaHiveIntervalDayTimeObjectInspector
case YearMonthIntervalType =>
PrimitiveObjectInspectorFactory.javaHiveIntervalYearMonthObjectInspector
// TODO decimal precision?
case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
@@ -830,6 +889,10 @@
getDecimalWritableConstantObjectInspector(value)
case Literal(_, NullType) =>
getPrimitiveNullWritableConstantObjectInspector
case Literal(_, DayTimeIntervalType) =>
getHiveIntervalDayTimeWritableConstantObjectInspector
case Literal(_, YearMonthIntervalType) =>
getHiveIntervalYearMonthWritableConstantObjectInspector
case Literal(value, ArrayType(dt, _)) =>
val listObjectInspector = toInspector(dt)
if (value == null) {
@@ -906,6 +969,10 @@
case _: JavaDateObjectInspector => DateType
case _: WritableTimestampObjectInspector => TimestampType
case _: JavaTimestampObjectInspector => TimestampType
case _: WritableHiveIntervalDayTimeObjectInspector => DayTimeIntervalType
case _: JavaHiveIntervalDayTimeObjectInspector => DayTimeIntervalType
case _: WritableHiveIntervalYearMonthObjectInspector => YearMonthIntervalType
case _: JavaHiveIntervalYearMonthObjectInspector => YearMonthIntervalType
case _: WritableVoidObjectInspector => NullType
case _: JavaVoidObjectInspector => NullType
}
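With these cases, `inspectorToDataType` recognizes both the Java and Writable interval inspectors. A sketch (assumes a `HiveInspectors` instance in scope, e.g. a test mixing in the trait):

```scala
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

val oi = PrimitiveObjectInspectorFactory.javaHiveIntervalDayTimeObjectInspector
// inspectorToDataType(oi) == DayTimeIntervalType
```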
@@ -967,6 +1034,14 @@
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
TypeInfoFactory.voidTypeInfo, null)
private def getHiveIntervalDayTimeWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
TypeInfoFactory.intervalDayTimeTypeInfo, null)
private def getHiveIntervalYearMonthWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
TypeInfoFactory.intervalYearMonthTypeInfo, null)
private def getStringWritable(value: Any): hadoopIo.Text =
if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes)
@@ -1024,6 +1099,22 @@
new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
}
private def getHiveIntervalDayTimeWritable(value: Any): hiveIo.HiveIntervalDayTimeWritable =
if (value == null) {
null
} else {
val duration = IntervalUtils.microsToDuration(value.asInstanceOf[Long])
new hiveIo.HiveIntervalDayTimeWritable(
new HiveIntervalDayTime(duration.getSeconds, duration.getNano))
}
private def getHiveIntervalYearMonthWritable(value: Any): hiveIo.HiveIntervalYearMonthWritable =
if (value == null) {
null
} else {
new hiveIo.HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(value.asInstanceOf[Int]))
}
private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
if (value == null) {
null
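The new writable helpers in isolation (taken when `preferWritable()` is true): the same `Duration` split, wrapped in the Hadoop writables the serde layer expects:

```scala
import org.apache.hadoop.hive.common.`type`.{HiveIntervalDayTime, HiveIntervalYearMonth}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.spark.sql.catalyst.util.IntervalUtils

val duration = IntervalUtils.microsToDuration(86400L * 1000000L)  // one day
val dayTimeWritable = new hiveIo.HiveIntervalDayTimeWritable(
  new HiveIntervalDayTime(duration.getSeconds, duration.getNano))
val yearMonthWritable = new hiveIo.HiveIntervalYearMonthWritable(
  new HiveIntervalYearMonth(10))
```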
@@ -1064,6 +1155,8 @@
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
case DayTimeIntervalType => intervalDayTimeTypeInfo
case YearMonthIntervalType => intervalYearMonthTypeInfo
case dt =>
throw new AnalysisException(
s"${dt.catalogString} cannot be converted to Hive TypeInfo")

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala

@@ -18,6 +18,8 @@
package org.apache.spark.sql.hive.execution
import java.sql.Timestamp
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.scalatest.exceptions.TestFailedException
@@ -25,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, TestUtils}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.execution._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -528,4 +531,61 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton
checkAnswer(query2, identity, Row("\\N,\\N,\\N") :: Nil)
}
test("SPARK-34879: HiveInspectors supports DayTimeIntervalType and YearMonthIntervalType") {
assume(TestUtils.testCommandAvailable("/bin/bash"))
withTempView("v") {
val df = Seq(
(Duration.ofDays(1),
Duration.ofSeconds(100).plusNanos(123456),
Duration.of(Long.MaxValue, ChronoUnit.MICROS),
Period.ofMonths(10)),
(Duration.ofDays(1),
Duration.ofSeconds(100).plusNanos(1123456789),
Duration.ofSeconds(Long.MaxValue / DateTimeConstants.MICROS_PER_SECOND),
Period.ofMonths(10))
).toDF("a", "b", "c", "d")
df.createTempView("v")
// Hive serde supports DayTimeIntervalType/YearMonthIntervalType as input and output data types
checkAnswer(
df,
(child: SparkPlan) => createScriptTransformationExec(
input = Seq(
df.col("a").expr,
df.col("b").expr,
df.col("c").expr,
df.col("d").expr),
script = "cat",
output = Seq(
AttributeReference("a", DayTimeIntervalType)(),
AttributeReference("b", DayTimeIntervalType)(),
AttributeReference("c", DayTimeIntervalType)(),
AttributeReference("d", YearMonthIntervalType)()),
child = child,
ioschema = hiveIOSchema),
df.select($"a", $"b", $"c", $"d").collect())
}
}
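// Editor's note (assumed arithmetic, not part of the PR): Long.MaxValue
// microseconds is roughly 106,751,991 days, so the 579,025,220-day interval in
// the next test cannot fit into DayTimeIntervalType's microseconds Long, and
// IntervalUtils.durationToMicros throws "java.lang.ArithmeticException: long overflow".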
test("SPARK-34879: HiveInspectors throw overflow when" +
" HiveIntervalDayTime overflow then DayTimeIntervalType") {
withTempView("v") {
val df = Seq(("579025220 15:30:06.000001000")).toDF("a")
df.createTempView("v")
val e = intercept[Exception] {
checkAnswer(
df,
(child: SparkPlan) => createScriptTransformationExec(
input = Seq(df.col("a").expr),
script = "cat",
output = Seq(AttributeReference("a", DayTimeIntervalType)()),
child = child,
ioschema = hiveIOSchema),
df.select($"a").collect())
}.getMessage
assert(e.contains("java.lang.ArithmeticException: long overflow"))
}
}
}