diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index 5d49007f28..5b111d17cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
+import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
 import org.apache.spark.unsafe.types.CalendarInterval
 
 abstract class ExtractIntervalPart[T](
@@ -125,33 +127,43 @@ object ExtractIntervalPart {
       source: Expression,
       errorHandleFunc: => Nothing): Expression = {
     (extractField.toUpperCase(Locale.ROOT), source.dataType) match {
-      case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", _: YearMonthIntervalType) =>
+      case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", YearMonthIntervalType(start, end))
+          if isUnitInIntervalRange(YEAR, start, end) =>
         ExtractANSIIntervalYears(source)
       case ("YEAR" | "Y" | "YEARS" | "YR" | "YRS", CalendarIntervalType) =>
         ExtractIntervalYears(source)
-      case ("MONTH" | "MON" | "MONS" | "MONTHS", _: YearMonthIntervalType) =>
+      case ("MONTH" | "MON" | "MONS" | "MONTHS", YearMonthIntervalType(start, end))
+          if isUnitInIntervalRange(MONTH, start, end) =>
         ExtractANSIIntervalMonths(source)
       case ("MONTH" | "MON" | "MONS" | "MONTHS", CalendarIntervalType) =>
         ExtractIntervalMonths(source)
-      case ("DAY" | "D" | "DAYS", _: DayTimeIntervalType) =>
+      case ("DAY" | "D" | "DAYS", DayTimeIntervalType(start, end))
+          if isUnitInIntervalRange(DAY, start, end) =>
         ExtractANSIIntervalDays(source)
       case ("DAY" | "D" | "DAYS", CalendarIntervalType) =>
         ExtractIntervalDays(source)
-      case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", _: DayTimeIntervalType) =>
+      case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", DayTimeIntervalType(start, end))
+          if isUnitInIntervalRange(HOUR, start, end) =>
         ExtractANSIIntervalHours(source)
       case ("HOUR" | "H" | "HOURS" | "HR" | "HRS", CalendarIntervalType) =>
         ExtractIntervalHours(source)
-      case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", _: DayTimeIntervalType) =>
+      case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", DayTimeIntervalType(start, end))
+          if isUnitInIntervalRange(MINUTE, start, end) =>
         ExtractANSIIntervalMinutes(source)
       case ("MINUTE" | "M" | "MIN" | "MINS" | "MINUTES", CalendarIntervalType) =>
         ExtractIntervalMinutes(source)
-      case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", _: DayTimeIntervalType) =>
+      case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", DayTimeIntervalType(start, end))
+          if isUnitInIntervalRange(SECOND, start, end) =>
         ExtractANSIIntervalSeconds(source)
       case ("SECOND" | "S" | "SEC" | "SECONDS" | "SECS", CalendarIntervalType) =>
         ExtractIntervalSeconds(source)
       case _ => errorHandleFunc
     }
   }
+
+  private def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean = {
+    start <= unit && unit <= end
+  }
 }
 
 abstract class IntervalNumOperation(
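
Note: with the guards above, extract() only resolves when the requested unit lies
within the source interval's declared field range; any other combination falls
through to errorHandleFunc. A minimal sketch of the resulting behavior (not part
of the patch; the object name and local SparkSession setup are illustrative):

import org.apache.spark.sql.{AnalysisException, SparkSession}

object ExtractFieldRangeDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]").appName("extract-demo").getOrCreate()

    // MONTH lies inside YEAR TO MONTH, so this resolves to
    // ExtractANSIIntervalMonths and returns 1.
    spark.sql("SELECT extract(MONTH FROM INTERVAL '2-1' YEAR TO MONTH)").show()

    // INTERVAL YEAR has startField == endField == YEAR, so extracting MONTH
    // now falls through to errorHandleFunc and fails analysis.
    try {
      spark.sql("SELECT extract(MONTH FROM INTERVAL '3' YEAR)").show()
    } catch {
      case e: AnalysisException => println(s"rejected as expected: ${e.getMessage}")
    }
    spark.stop()
  }
}
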
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
new file mode 100644
index 0000000000..c7e307b625
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntervalFunctionsSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.time.{Duration, Period}
+
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM}
+import org.apache.spark.sql.types.DataTypeTestUtils._
+
+class IntervalFunctionsSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("SPARK-36022: Respect interval fields in extract") {
+    yearMonthIntervalTypes.foreach { dtype =>
+      val ymDF = Seq(Period.of(1, 2, 0)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("year" -> 1, "month" -> 2)
+      YM.yearMonthFields.foreach { field =>
+        val extractUnit = YM.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(ymDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            ymDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+
+    dayTimeIntervalTypes.foreach { dtype =>
+      val dtDF = Seq(Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4)).toDF
+        .select($"value" cast dtype as "value")
+      val expectedMap = Map("day" -> 1, "hour" -> 2, "minute" -> 3, "second" -> 4)
+      DT.dayTimeFields.foreach { field =>
+        val extractUnit = DT.fieldToString(field)
+        val extractExpr = s"extract($extractUnit FROM value)"
+        if (dtype.startField <= field && field <= dtype.endField) {
+          checkAnswer(dtDF.selectExpr(extractExpr), Row(expectedMap(extractUnit)))
+        } else {
+          intercept[AnalysisException] {
+            dtDF.selectExpr(extractExpr)
+          }
+        }
+      }
+    }
+  }
+}
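
The suite above loops over every valid (startField, endField) combination from
DataTypeTestUtils and exercises both the accept and reject paths. For reference,
a self-contained sketch of the guard's semantics (assumed ordinals mirror
DayTimeIntervalType: DAY=0, HOUR=1, MINUTE=2, SECOND=3; the object name is
illustrative):

object UnitRangeGuardSketch {
  // Same shape as the private helper added to ExtractIntervalPart: a unit is
  // extractable iff its ordinal falls inside the declared [start, end] range.
  def isUnitInIntervalRange(unit: Byte, start: Byte, end: Byte): Boolean =
    start <= unit && unit <= end

  def main(args: Array[String]): Unit = {
    val (day, hour, minute, second) = (0.toByte, 1.toByte, 2.toByte, 3.toByte)
    // For INTERVAL HOUR TO SECOND: HOUR, MINUTE and SECOND are extractable...
    assert(isUnitInIntervalRange(minute, hour, second))
    // ...but DAY is not, which is the branch the intercept[AnalysisException]
    // checks in the suite above.
    assert(!isUnitInIntervalRange(day, hour, second))
    println("guard semantics hold")
  }
}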