[SPARK-36022][SQL] Respect interval fields in extract
### What changes were proposed in this pull request?
This PR fixes an issue with `extract` on ANSI interval types.
`Extract` should process only existing fields of interval types. For example:
```
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021-11' YEAR TO MONTH);
11
spark-sql> SELECT EXTRACT(MONTH FROM INTERVAL '2021' YEAR);
0
```
The last command should fail because the month field is not present in `INTERVAL YEAR`.
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
Closes #33247 from sarutak/fix-extract-interval.
Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
(cherry picked from commit 39002cb995
)
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
parent
12b29cd41a
commit
429d1780b3
|
@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
|
||||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
|
||||||
|
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
|
||||||
import org.apache.spark.unsafe.types.CalendarInterval
|
import org.apache.spark.unsafe.types.CalendarInterval
|
||||||
|
|
||||||
abstract class ExtractIntervalPart[T](
|
abstract class ExtractIntervalPart[T](
|
||||||
|
@ -125,33 +127,43 @@ object ExtractIntervalPart {
|
||||||
source: Expression,
|
source: Expression,
|
||||||
errorHandleFunc: => Nothing): Expression = {
|
errorHandleFunc: => Nothing): Expression = {
|
||||||
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
|
(extractField.toUpperCase(Locale.ROOT), source.dataType) match {
|
||||||
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
|
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(YEAR, start, end) =>
|
||||||
ExtractANSIIntervalYears(source)
|
ExtractANSIIntervalYears(source)
|
||||||
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
|
case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
|
||||||
ExtractIntervalYears(source)
|
ExtractIntervalYears(source)
|
||||||
case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
|
case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(MONTH, start, end) =>
|
||||||
ExtractANSIIntervalMonths(source)
|
ExtractANSIIntervalMonths(source)
|
||||||
case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
|
case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
|
||||||
ExtractIntervalMonths(source)
|
ExtractIntervalMonths(source)
|
||||||
case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
|
case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(DAY, start, end) =>
|
||||||
ExtractANSIIntervalDays(source)
|
ExtractANSIIntervalDays(source)
|
||||||
case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
|
case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
|
||||||
ExtractIntervalDays(source)
|
ExtractIntervalDays(source)
|
||||||
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
|
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(HOUR, start, end) =>
|
||||||
ExtractANSIIntervalHours(source)
|
ExtractANSIIntervalHours(source)
|
||||||
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
|
case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
|
||||||
ExtractIntervalHours(source)
|
ExtractIntervalHours(source)
|
||||||
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
|
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(MINUTE, start, end) =>
|
||||||
ExtractANSIIntervalMinutes(source)
|
ExtractANSIIntervalMinutes(source)
|
||||||
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
|
case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
|
||||||
ExtractIntervalMinutes(source)
|
ExtractIntervalMinutes(source)
|
||||||
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
|
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
|
||||||
|
if isUnitInIntervalRange(SECOND, start, end) =>
|
||||||
ExtractANSIIntervalSeconds(source)
|
ExtractANSIIntervalSeconds(source)
|
||||||
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
|
case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
|
||||||
ExtractIntervalSeconds(source)
|
ExtractIntervalSeconds(source)
|
||||||
case _ => errorHandleFunc
|
case _ => errorHandleFunc
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Tells whether an interval field lies within the inclusive field range of an interval type.
 *
 * @param unit  the field being asked for
 * @param start the first (lowest) field of the range
 * @param end   the last (highest) field of the range
 * @return `true` when `start <= unit <= end`, `false` otherwise
 */
private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
  // Reject anything below the lower bound first, then compare against the upper bound.
  if (unit < start) false else unit <= end
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract class IntervalNumOperation(
|
abstract class IntervalNumOperation(
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.sql
|
||||||
|
|
||||||
|
import java.time.{Duration, Period}
|
||||||
|
|
||||||
|
import org.apache.spark.sql.test.SharedSparkSession
|
||||||
|
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
|
||||||
|
import org.apache.spark.sql.types.DataTypeTestUtils._
|
||||||
|
|
||||||
|
/**
 * Query-level tests for SQL functions applied to year-month and day-time interval columns.
 */
class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
  import testImplicits._

  test("SPARK-36022: Respect interval fields in extract") {
    // Year-month intervals: for every interval type under test, `extract` must succeed for
    // fields inside [startField, endField] and be rejected for fields outside it.
    yearMonthIntervalTypes.foreach { dtype =>
      // A single cast to the type under test is enough; casting again to the same type
      // would be a no-op.
      val ymDF = Seq(Period.of(1, 2, 0)).toDF
        .select($"value" cast dtype as "value")
      val expectedMap = Map("year" -> 1, "month" -> 2)
      YM.yearMonthFields.foreach { field =>
        val extractUnit = YM.fieldToString(field)
        val extractExpr = s"extract($extractUnit FROM value)"
        if (dtype.startField <= field && field <= dtype.endField) {
          checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
        } else {
          // The field is not part of this interval type, so the query must be rejected.
          intercept[AnalysisException] {
            ymDF.selectExpr(extractExpr)
          }
        }
      }
    }

    // Day-time intervals: same contract as above, over the day/hour/minute/second fields.
    dayTimeIntervalTypes.foreach { dtype =>
      val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
        .select($"value" cast dtype as "value")
      val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
      DT.dayTimeFields.foreach { field =>
        val extractUnit = DT.fieldToString(field)
        val extractExpr = s"extract($extractUnit FROM value)"
        if (dtype.startField <= field && field <= dtype.endField) {
          checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
        } else {
          // The field is not part of this interval type, so the query must be rejected.
          intercept[AnalysisException] {
            dtDF.selectExpr(extractExpr)
          }
        }
      }
    }
  }
}
|
Loading…
Reference in a new issue