diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 83ee1913da..c9862cb629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -653,6 +653,12 @@ object StructType extends AbstractDataType { case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt + case (YearMonthIntervalType(lstart, lend), YearMonthIntervalType(rstart, rend)) => + YearMonthIntervalType(Math.min(lstart, rstart).toByte, Math.max(lend, rend).toByte) + + case (DayTimeIntervalType(lstart, lend), DayTimeIntervalType(rstart, rend)) => + DayTimeIntervalType(Math.min(lstart, rstart).toByte, Math.max(lend, rend).toByte) + case (leftType, rightType) if leftType == rightType => leftType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 8db3831392..8cc04c78d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DayTimeIntervalType => DT} import org.apache.spark.sql.types.{YearMonthIntervalType => YM} +import org.apache.spark.sql.types.DayTimeIntervalType._ import org.apache.spark.sql.types.StructType.fromDDL +import org.apache.spark.sql.types.YearMonthIntervalType._ class StructTypeSuite extends SparkFunSuite with SQLHelper { @@ -382,4 +384,25 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { assert(e.getMessage.contains( "Field name a2.element.C.name is invalid: a2.element.c is not a struct")) } + + test("SPARK-36807: Merge ANSI interval types to a tightest common type") { + Seq( + (YM(YEAR), YM(YEAR)) -> YM(YEAR), + (YM(YEAR), YM(MONTH)) -> YM(YEAR, MONTH), + (YM(MONTH), YM(MONTH)) -> YM(MONTH), + (YM(YEAR, MONTH), YM(YEAR)) -> YM(YEAR, MONTH), + (YM(YEAR, MONTH), YM(YEAR, MONTH)) -> YM(YEAR, MONTH), + (DT(DAY), DT(DAY)) -> DT(DAY), + (DT(SECOND), DT(SECOND)) -> DT(SECOND), + (DT(DAY), DT(SECOND)) -> DT(DAY, SECOND), + (DT(HOUR, SECOND), DT(DAY, MINUTE)) -> DT(DAY, SECOND), + (DT(HOUR, MINUTE), DT(DAY, SECOND)) -> DT(DAY, SECOND) + ).foreach { case ((i1, i2), expected) => + val st1 = new StructType().add("interval", i1) + val st2 = new StructType().add("interval", i2) + val expectedStruct = new StructType().add("interval", expected) + assert(st1.merge(st2) === expectedStruct) + assert(st2.merge(st1) === expectedStruct) + } + } }