[SPARK-35776][SQL][TESTS] Check all year-month interval types in arrow

### What changes were proposed in this pull request?
Add tests to check that all year-month interval types are supported in (de-)serialization from/to Arrow format.

### Why are the changes needed?
New tests should improve test coverage.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests.

Closes #32993 from AngersZhuuuu/SPARK-35776.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
Angerszhuuuu 2021-06-23 10:59:50 +03:00 committed by Max Gekk
parent 79e3d0d98f
commit 7c1a9dd3f5

View file

@ -30,7 +30,11 @@ class ArrowWriterSuite extends SparkFunSuite {
test("simple") { test("simple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
val schema = new StructType().add("value", dt, nullable = true) val avroDatatype = dt match {
case _: YearMonthIntervalType => YearMonthIntervalType()
case tpe => tpe
}
val schema = new StructType().add("value", avroDatatype, nullable = true)
val writer = ArrowWriter.create(schema, timeZoneId) val writer = ArrowWriter.create(schema, timeZoneId)
assert(writer.schema === schema) assert(writer.schema === schema)
@ -77,7 +81,8 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DateType, Seq(0, 1, 2, null, 4)) check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
check(NullType, Seq(null, null, null)) check(NullType, Seq(null, null, null))
check(YearMonthIntervalType(), Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)) DataTypeTestUtils.yearMonthIntervalTypes
.foreach(check(_, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue)))
check(DayTimeIntervalType(), Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L), check(DayTimeIntervalType(), Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L),
(Long.MinValue + 808L))) (Long.MinValue + 808L)))
} }
@ -108,7 +113,11 @@ class ArrowWriterSuite extends SparkFunSuite {
test("get multiple") { test("get multiple") {
def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = {
val schema = new StructType().add("value", dt, nullable = false) val avroDatatype = dt match {
case _: YearMonthIntervalType => YearMonthIntervalType()
case tpe => tpe
}
val schema = new StructType().add("value", avroDatatype, nullable = false)
val writer = ArrowWriter.create(schema, timeZoneId) val writer = ArrowWriter.create(schema, timeZoneId)
assert(writer.schema === schema) assert(writer.schema === schema)
@ -144,8 +153,7 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, (0 until 10).map(_.toDouble)) check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10)) check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
// TODO(SPARK-35776): Check all year-month interval types in arrow DataTypeTestUtils.yearMonthIntervalTypes.foreach(check(_, (0 until 14)))
check(YearMonthIntervalType(), (0 until 10))
// TODO(SPARK-35731): Check all day-time interval types in arrow // TODO(SPARK-35731): Check all day-time interval types in arrow
check(DayTimeIntervalType(), (-10 until 10).map(_ * 1000.toLong)) check(DayTimeIntervalType(), (-10 until 10).map(_ * 1000.toLong))
} }