[SPARK-36022][SQL] Respect interval fields in extract
### What changes were proposed in this pull request?
This PR fixes an issue about `extract`.
`Extract` should process only existing fields of interval types. For example:
```
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
11
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR);
0
```
The last command should fail as the month field doesn't present in INTERVAL YEAR.
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
Closes #33247 from sarutak/fix-extract-interval.
Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
(cherry picked from commit 39002cb995
)
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
parent
12b29cd41a
commit
429d1780b3
|
@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
|
|||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
|
||||
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
abstract class ExtractIntervalPart[T](
|
||||
|
@ -125,33 +127,43 @@ object ExtractIntervalPart {
|
|||
source: Expression,
|
||||
errorHandleFunc: => Nothing): Expression = {
|
||||
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
|
||||
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
|
||||
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
|
||||
if isUnitInIntervalRange(YEAR, start, end) =>
|
||||
ExtractANSIIntervalYears(source)
|
||||
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
|
||||
ExtractIntervalYears(source)
|
||||
case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
|
||||
case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
|
||||
if isUnitInIntervalRange(MONTH, start, end) =>
|
||||
ExtractANSIIntervalMonths(source)
|
||||
case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
|
||||
ExtractIntervalMonths(source)
|
||||
case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
|
||||
case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
|
||||
if isUnitInIntervalRange(DAY, start, end) =>
|
||||
ExtractANSIIntervalDays(source)
|
||||
case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
|
||||
ExtractIntervalDays(source)
|
||||
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
|
||||
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
|
||||
if isUnitInIntervalRange(HOUR, start, end) =>
|
||||
ExtractANSIIntervalHours(source)
|
||||
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
|
||||
ExtractIntervalHours(source)
|
||||
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
|
||||
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
|
||||
if isUnitInIntervalRange(MINUTE, start, end) =>
|
||||
ExtractANSIIntervalMinutes(source)
|
||||
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
|
||||
ExtractIntervalMinutes(source)
|
||||
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
|
||||
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
|
||||
if isUnitInIntervalRange(SECOND, start, end) =>
|
||||
ExtractANSIIntervalSeconds(source)
|
||||
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
|
||||
ExtractIntervalSeconds(source)
|
||||
case _ => errorHandleFunc
|
||||
}
|
||||
}
|
||||
|
||||
private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
|
||||
start <= unit && unit <= end
|
||||
}
|
||||
}
|
||||
|
||||
abstract class IntervalNumOperation(
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.time.{Duration, Period}
|
||||
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
|
||||
import org.apache.spark.sql.types.DataTypeTestUtils._
|
||||
|
||||
class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
|
||||
import testImplicits._
|
||||
|
||||
test("SPARK-36022: Respect interval fields in extract") {
|
||||
yearMonthIntervalTypes.foreach { dtype =>
|
||||
val ymDF = Seq(Period.of(1, 2, 0)).toDF.select($"value" cast dtype as "value")
|
||||
.select($"value" cast dtype as "value")
|
||||
val expectedMap = Map("year" -> 1, "month" -> 2)
|
||||
YM.yearMonthFields.foreach { field =>
|
||||
val extractUnit = YM.fieldToString(field)
|
||||
val extractExpr = s"extract($extractUnit FROM value)"
|
||||
if (dtype.startField <= field && field <= dtype.endField) {
|
||||
checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
|
||||
} else {
|
||||
intercept[AnalysisException] {
|
||||
ymDF.selectExpr(extractExpr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dayTimeIntervalTypes.foreach { dtype =>
|
||||
val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
|
||||
.select($"value" cast dtype as "value")
|
||||
val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
|
||||
DT.dayTimeFields.foreach { field =>
|
||||
val extractUnit = DT.fieldToString(field)
|
||||
val extractExpr = s"extract($extractUnit FROM value)"
|
||||
if (dtype.startField <= field && field <= dtype.endField) {
|
||||
checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
|
||||
} else {
|
||||
intercept[AnalysisException] {
|
||||
dtDF.selectExpr(extractExpr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue