diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index c56a2103a9..8ebff296bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -30,7 +30,11 @@ class ArrowWriterSuite extends SparkFunSuite { test("simple") { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { - val schema = new StructType().add("value", dt, nullable = true) + val avroDatatype = dt match { + case _: YearMonthIntervalType => YearMonthIntervalType() + case tpe => tpe + } + val schema = new StructType().add("value", avroDatatype, nullable = true) val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) @@ -77,7 +81,8 @@ class ArrowWriterSuite extends SparkFunSuite { check(DateType, Seq(0, 1, 2, null, 4)) check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") check(NullType, Seq(null, null, null)) - check(YearMonthIntervalType(), Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)) + DataTypeTestUtils.yearMonthIntervalTypes + .foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))) check(DayTimeIntervalType(), Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), (Long.MinValue + 808L))) } @@ -108,7 +113,11 @@ class ArrowWriterSuite extends SparkFunSuite { test("get multiple") { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { - val schema = new StructType().add("value", dt, nullable = false) + val avroDatatype = dt match { + case _: YearMonthIntervalType => YearMonthIntervalType() + case tpe => tpe + } + val schema = new StructType().add("value", avroDatatype, nullable = false) val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) @@ -144,8 +153,7 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, (0 until 10).map(_.toDouble)) check(DateType, (0 until 10)) check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") - // TODO(SPARK-35776): Check all year-month interval types in arrow - check(YearMonthIntervalType(), (0 until 10)) + DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14))) // TODO(SPARK-35731): Check all day-time interval types in arrow check(DayTimeIntervalType(), (-10 until 10).map(_ * 1000.toLong)) }