[SPARK-36022][SQL] Respect interval fields in extract

### What changes were proposed in this pull request?

This PR fixes an issue about `extract`.
`Extract` should process only existing fields of interval types. For example:

```
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
11
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR);
0
```
The last command should fail as the month field doesn't present in INTERVAL YEAR.

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

Closes #33247 from sarutak/fix-extract-interval.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
(cherry picked from commit 39002cb995)
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
Kousuke Saruta 2021-07-08 09:40:57 +03:00 committed by Max Gekk
parent 12b29cd41a
commit 429d1780b3
2 changed files with 82 additions and 6 deletions

View file

@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.CalendarInterval
abstract class ExtractIntervalPart[T](
@ -125,33 +127,43 @@ object ExtractIntervalPart {
source: Expression,
errorHandleFunc: => Nothing): Expression = {
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
if isUnitInIntervalRange(YEAR, start, end) =>
ExtractANSIIntervalYears(source)
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
ExtractIntervalYears(source)
case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
if isUnitInIntervalRange(MONTH, start, end) =>
ExtractANSIIntervalMonths(source)
case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
ExtractIntervalMonths(source)
case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
if isUnitInIntervalRange(DAY, start, end) =>
ExtractANSIIntervalDays(source)
case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
ExtractIntervalDays(source)
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
if isUnitInIntervalRange(HOUR, start, end) =>
ExtractANSIIntervalHours(source)
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
ExtractIntervalHours(source)
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
if isUnitInIntervalRange(MINUTE, start, end) =>
ExtractANSIIntervalMinutes(source)
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
ExtractIntervalMinutes(source)
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
if isUnitInIntervalRange(SECOND, start, end) =>
ExtractANSIIntervalSeconds(source)
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
ExtractIntervalSeconds(source)
case _ => errorHandleFunc
}
}
private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
start <= unit && unit <= end
}
}
abstract class IntervalNumOperation(

View file

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.time.{Duration, Period}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
import org.apache.spark.sql.types.DataTypeTestUtils._
class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
import testImplicits._
test("SPARK-36022: Respect interval fields in extract") {
yearMonthIntervalTypes.foreach { dtype =>
val ymDF = Seq(Period.of(1, 2, 0)).toDF.select($"value" cast dtype as "value")
.select($"value" cast dtype as "value")
val expectedMap = Map("year" -> 1, "month" -> 2)
YM.yearMonthFields.foreach { field =>
val extractUnit = YM.fieldToString(field)
val extractExpr = s"extract($extractUnit FROM value)"
if (dtype.startField <= field && field <= dtype.endField) {
checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
} else {
intercept[AnalysisException] {
ymDF.selectExpr(extractExpr)
}
}
}
}
dayTimeIntervalTypes.foreach { dtype =>
val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
.select($"value" cast dtype as "value")
val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
DT.dayTimeFields.foreach { field =>
val extractUnit = DT.fieldToString(field)
val extractExpr = s"extract($extractUnit FROM value)"
if (dtype.startField <= field && field <= dtype.endField) {
checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
} else {
intercept[AnalysisException] {
dtDF.selectExpr(extractExpr)
}
}
}
}
}
}