From 015c59843c902339c365e3244b6d24f9050c8080 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 29 Mar 2021 08:38:20 +0300
Subject: [PATCH] [SPARK-34879][SQL] HiveInspector supports DayTimeIntervalType
 and YearMonthIntervalType

### What changes were proposed in this pull request?
Make `HiveInspectors` support `DayTimeIntervalType` and `YearMonthIntervalType`, so that these two types can be used in Hive UDFs and in `HiveScriptTransformation`. (Two illustrative Scala sketches — one of the conversion logic, one of UDF usage — are appended after the diff.)

### Why are the changes needed?
Support more data types when using the Hive serde.

### Does this PR introduce _any_ user-facing change?
Users can use `DayTimeIntervalType` and `YearMonthIntervalType` in Hive UDFs and in `HiveScriptTransformation`.

### How was this patch tested?
Added unit tests.

Closes #31979 from AngersZhuuuu/SPARK-34879.

Authored-by: Angerszhuuuu
Signed-off-by: Max Gekk
---
 .../spark/sql/hive/HiveInspectors.scala       | 95 ++++++++++++++++++-
 .../HiveScriptTransformationSuite.scala       | 60 ++++++++++++
 2 files changed, 154 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 9213173bbc..37a1fc0bae 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.hive
 
 import java.lang.reflect.{ParameterizedType, Type, WildcardType}
+import java.time.Duration
 
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.{io => hadoopIo}
-import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveIntervalDayTime, HiveIntervalYearMonth, HiveVarchar}
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
 import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive._
@@ -346,6 +347,17 @@
         withNullSafe(o => getTimestampWritable(o))
       case _: TimestampObjectInspector =>
         withNullSafe(o => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]))
+      case _: HiveIntervalDayTimeObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getHiveIntervalDayTimeWritable(o))
+      case _: HiveIntervalDayTimeObjectInspector =>
+        withNullSafe(o => {
+          val duration = IntervalUtils.microsToDuration(o.asInstanceOf[Long])
+          new HiveIntervalDayTime(duration.getSeconds, duration.getNano)
+        })
+      case _: HiveIntervalYearMonthObjectInspector if x.preferWritable() =>
+        withNullSafe(o => getHiveIntervalYearMonthWritable(o))
+      case _: HiveIntervalYearMonthObjectInspector =>
+        withNullSafe(o => new HiveIntervalYearMonth(o.asInstanceOf[Int]))
       case _: VoidObjectInspector =>
         (_: Any) => null // always be null for void object inspector
     }
@@ -512,6 +524,13 @@
       _ => constant
     case poi: VoidObjectInspector =>
       _ => null // always be null for void object inspector
+    case dt: WritableConstantHiveIntervalDayTimeObjectInspector =>
+      val constant = dt.getWritableConstantValue.asInstanceOf[HiveIntervalDayTime]
+      _ => IntervalUtils.durationToMicros(
+        Duration.ofSeconds(constant.getTotalSeconds).plusNanos(constant.getNanos.toLong))
+    case ym: WritableConstantHiveIntervalYearMonthObjectInspector =>
+      val constant = ym.getWritableConstantValue.asInstanceOf[HiveIntervalYearMonth]
+      _ => constant.getTotalMonths
     case pi: PrimitiveObjectInspector => pi match {
       // We think HiveVarchar/HiveChar is also a String
       case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
@@ -647,6 +666,42 @@
             null
           }
         }
+      case dt: HiveIntervalDayTimeObjectInspector if dt.preferWritable() =>
+        data: Any => {
+          if (data != null) {
+            val dayTime = dt.getPrimitiveWritableObject(data).getHiveIntervalDayTime
+            IntervalUtils.durationToMicros(
+              Duration.ofSeconds(dayTime.getTotalSeconds).plusNanos(dayTime.getNanos.toLong))
+          } else {
+            null
+          }
+        }
+      case dt: HiveIntervalDayTimeObjectInspector =>
+        data: Any => {
+          if (data != null) {
+            val dayTime = dt.getPrimitiveJavaObject(data)
+            IntervalUtils.durationToMicros(
+              Duration.ofSeconds(dayTime.getTotalSeconds).plusNanos(dayTime.getNanos.toLong))
+          } else {
+            null
+          }
+        }
+      case ym: HiveIntervalYearMonthObjectInspector if ym.preferWritable() =>
+        data: Any => {
+          if (data != null) {
+            ym.getPrimitiveWritableObject(data).getHiveIntervalYearMonth.getTotalMonths
+          } else {
+            null
+          }
+        }
+      case ym: HiveIntervalYearMonthObjectInspector =>
+        data: Any => {
+          if (data != null) {
+            ym.getPrimitiveJavaObject(data).getTotalMonths
+          } else {
+            null
+          }
+        }
       case _ =>
         data: Any => {
           if (data != null) {
@@ -785,6 +840,10 @@
     case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
     case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
     case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+    case DayTimeIntervalType =>
+      PrimitiveObjectInspectorFactory.javaHiveIntervalDayTimeObjectInspector
+    case YearMonthIntervalType =>
+      PrimitiveObjectInspectorFactory.javaHiveIntervalYearMonthObjectInspector
     // TODO decimal precision?
     case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
     case StructType(fields) =>
@@ -830,6 +889,10 @@
       getDecimalWritableConstantObjectInspector(value)
     case Literal(_, NullType) =>
       getPrimitiveNullWritableConstantObjectInspector
+    case Literal(_, DayTimeIntervalType) =>
+      getHiveIntervalDayTimeWritableConstantObjectInspector
+    case Literal(_, YearMonthIntervalType) =>
+      getHiveIntervalYearMonthWritableConstantObjectInspector
     case Literal(value, ArrayType(dt, _)) =>
       val listObjectInspector = toInspector(dt)
       if (value == null) {
@@ -906,6 +969,10 @@
     case _: JavaDateObjectInspector => DateType
     case _: WritableTimestampObjectInspector => TimestampType
     case _: JavaTimestampObjectInspector => TimestampType
+    case _: WritableHiveIntervalDayTimeObjectInspector => DayTimeIntervalType
+    case _: JavaHiveIntervalDayTimeObjectInspector => DayTimeIntervalType
+    case _: WritableHiveIntervalYearMonthObjectInspector => YearMonthIntervalType
+    case _: JavaHiveIntervalYearMonthObjectInspector => YearMonthIntervalType
     case _: WritableVoidObjectInspector => NullType
     case _: JavaVoidObjectInspector => NullType
   }
@@ -967,6 +1034,14 @@
     PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
       TypeInfoFactory.voidTypeInfo, null)
 
+  private def getHiveIntervalDayTimeWritableConstantObjectInspector: ObjectInspector =
+    PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+      TypeInfoFactory.intervalDayTimeTypeInfo, null)
+
+  private def getHiveIntervalYearMonthWritableConstantObjectInspector: ObjectInspector =
+    PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
+      TypeInfoFactory.intervalYearMonthTypeInfo, null)
+
   private def getStringWritable(value: Any): hadoopIo.Text =
     if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes)
 
@@ -1024,6 +1099,22 @@
       new hiveIo.TimestampWritable(DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]))
     }
 
+  private def getHiveIntervalDayTimeWritable(value: Any): hiveIo.HiveIntervalDayTimeWritable =
+    if (value == null) {
+      null
+    } else {
+      val duration = IntervalUtils.microsToDuration(value.asInstanceOf[Long])
+      new hiveIo.HiveIntervalDayTimeWritable(
+        new HiveIntervalDayTime(duration.getSeconds, duration.getNano))
+    }
+
+  private def getHiveIntervalYearMonthWritable(value: Any): hiveIo.HiveIntervalYearMonthWritable =
+    if (value == null) {
+      null
+    } else {
+      new hiveIo.HiveIntervalYearMonthWritable(new HiveIntervalYearMonth(value.asInstanceOf[Int]))
+    }
+
   private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
     if (value == null) {
       null
@@ -1064,6 +1155,8 @@
       case DateType => dateTypeInfo
       case TimestampType => timestampTypeInfo
       case NullType => voidTypeInfo
+      case DayTimeIntervalType => intervalDayTimeTypeInfo
+      case YearMonthIntervalType => intervalYearMonthTypeInfo
       case dt =>
         throw new AnalysisException(
           s"${dt.catalogString} cannot be converted to Hive TypeInfo")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
index 3892caa51e..a2a996b419 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.hive.execution
 
 import java.sql.Timestamp
+import java.time.{Duration, Period}
+import java.time.temporal.ChronoUnit
 
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 import org.scalatest.exceptions.TestFailedException
@@ -25,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -528,4 +531,61 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
 
     checkAnswer(query2, identity, Row("\\N,\\N,\\N") :: Nil)
   }
+
+  test("SPARK-34879: HiveInspectors supports DayTimeIntervalType and YearMonthIntervalType") {
+    assume(TestUtils.testCommandAvailable("/bin/bash"))
+    withTempView("v") {
+      val df = Seq(
+        (Duration.ofDays(1),
+          Duration.ofSeconds(100).plusNanos(123456),
+          Duration.of(Long.MaxValue, ChronoUnit.MICROS),
+          Period.ofMonths(10)),
+        (Duration.ofDays(1),
+          Duration.ofSeconds(100).plusNanos(1123456789),
+          Duration.ofSeconds(Long.MaxValue / DateTimeConstants.MICROS_PER_SECOND),
+          Period.ofMonths(10))
+      ).toDF("a", "b", "c", "d")
+      df.createTempView("v")
+
+      // Hive serde supports DayTimeIntervalType/YearMonthIntervalType as input and output data types
+      checkAnswer(
+        df,
+        (child: SparkPlan) => createScriptTransformationExec(
+          input = Seq(
+            df.col("a").expr,
+            df.col("b").expr,
+            df.col("c").expr,
+            df.col("d").expr),
+          script = "cat",
+          output = Seq(
+            AttributeReference("a", DayTimeIntervalType)(),
+            AttributeReference("b", DayTimeIntervalType)(),
+            AttributeReference("c", DayTimeIntervalType)(),
+            AttributeReference("d", YearMonthIntervalType)()),
+          child = child,
+          ioschema = hiveIOSchema),
+        df.select($"a", $"b", $"c", $"d").collect())
+    }
+  }
+
+  test("SPARK-34879: HiveInspectors throws an overflow error when" +
+    " a HiveIntervalDayTime overflows DayTimeIntervalType") {
+    withTempView("v") {
+      val df = Seq(("579025220 15:30:06.000001000")).toDF("a")
+      df.createTempView("v")
+
+      val e = intercept[Exception] {
+        checkAnswer(
+          df,
+          (child: SparkPlan) => createScriptTransformationExec(
+            input = Seq(df.col("a").expr),
+            script = "cat",
+            output = Seq(AttributeReference("a", DayTimeIntervalType)()),
+            child = child,
+            ioschema = hiveIOSchema),
+          df.select($"a").collect())
+      }.getMessage
+      assert(e.contains("java.lang.ArithmeticException: long overflow"))
+    }
+  }
 }
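
The heart of the patch is the mapping between Catalyst's physical interval values (a `Long` counting microseconds for `DayTimeIntervalType`, an `Int` counting months for `YearMonthIntervalType`) and Hive's `HiveIntervalDayTime`/`HiveIntervalYearMonth`. Below is a minimal, self-contained Scala sketch of that mapping. The object name, constants, and `main` scaffold are illustrative; the two helpers mirror the observed behavior of `IntervalUtils.microsToDuration`/`durationToMicros` used in the diff, not necessarily their exact implementation.

```scala
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit

// Illustrative sketch of the Catalyst <-> Hive interval conversions.
object IntervalConversionSketch {
  private val MICROS_PER_SECOND = 1000000L
  private val NANOS_PER_MICROS = 1000L

  // Catalyst Long micros -> java.time.Duration, i.e. the (getSeconds, getNano)
  // pair that feeds new HiveIntervalDayTime(seconds, nanos) in the patch.
  def microsToDuration(micros: Long): Duration =
    Duration.of(micros, ChronoUnit.MICROS)

  // Duration -> Catalyst Long micros. The exact arithmetic throws
  // java.lang.ArithmeticException("long overflow") when the interval does not
  // fit, which is the message the second new test asserts on.
  def durationToMicros(duration: Duration): Long =
    Math.addExact(
      Math.multiplyExact(duration.getSeconds, MICROS_PER_SECOND),
      duration.getNano / NANOS_PER_MICROS)

  def main(args: Array[String]): Unit = {
    val micros = 100L * MICROS_PER_SECOND + 123 // 100 seconds and 123 microseconds
    val d = microsToDuration(micros)
    println((d.getSeconds, d.getNano))          // (100,123000)
    assert(durationToMicros(d) == micros)       // the round trip is lossless

    // YearMonthIntervalType is simpler: Catalyst's Int of months maps 1:1 to
    // new HiveIntervalYearMonth(totalMonths).
    val months = Period.ofMonths(10).toTotalMonths.toInt
    assert(months == 10)

    // Long microseconds top out around 106751991 days, so the test value
    // "579025220 15:30:06.000001000" cannot be converted back:
    try durationToMicros(Duration.ofDays(579025220L).plusHours(15))
    catch { case e: ArithmeticException => println(e.getMessage) } // long overflow
  }
}
```

The `long overflow` printed at the end is the same `java.lang.ArithmeticException` the second new test asserts on: `Long` microseconds hold roughly `Long.MaxValue / 1000000 / 86400 ≈ 106751991` days, while the test feeds 579025220 days.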
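
The description also says the two types become usable in Hive UDFs. A hedged sketch of what that enables follows; the `DoubleInterval` class, the function name, and the registration step are hypothetical and not part of this patch — only the wrapping of the Catalyst value into a `HiveIntervalDayTime` before `evaluate` (and the unwrapping of the result back to microseconds) comes from the diff.

```scala
import org.apache.hadoop.hive.common.`type`.HiveIntervalDayTime
import org.apache.hadoop.hive.ql.exec.UDF

// Hypothetical classic Hive UDF that doubles the whole-second part of a
// day-time interval. With this patch, Spark's wrappers convert the Catalyst
// Long (microseconds) into the HiveIntervalDayTime passed to evaluate(), and
// unwrap the returned value back into microseconds.
class DoubleInterval extends UDF {
  def evaluate(interval: HiveIntervalDayTime): HiveIntervalDayTime =
    if (interval == null) {
      null
    } else {
      new HiveIntervalDayTime(interval.getTotalSeconds * 2, interval.getNanos)
    }
}
```

Registered with `CREATE TEMPORARY FUNCTION double_interval AS 'DoubleInterval'` (both names illustrative), such a UDF could then be applied to a `DayTimeIntervalType` column.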