[SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM

### What changes were proposed in this pull request?
Refactor the code that parses strings into day-time (DT) and year-month (YM) intervals.

### Why are the changes needed?
Extracting the common code for parsing strings into DT/YM intervals makes the code easier to maintain.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing unit tests.

Closes #33217 from AngersZhuuuu/SPARK-35735-35768.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
This commit is contained in:
Angerszhuuuu 2021-07-06 13:51:06 +03:00 committed by Max Gekk
parent def8bc5c96
commit 26d1bb16bc
2 changed files with 123 additions and 106 deletions

View file

@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
// The style of textual representation of intervals
@ -110,7 +110,7 @@ object IntervalUtils {
private val yearMonthIndividualLiteralRegex =
(s"(?i)^INTERVAL\\s+([+|-])?'$yearMonthIndividualPatternString'\\s+(YEAR|MONTH)$$").r
private def getSign(firstSign: String, secondSign: String): Int = {
private def finalSign(firstSign: String, secondSign: String = null): Int = {
(firstSign, secondSign) match {
case ("-", "-") => 1
case ("-", _) => -1
@ -119,6 +119,39 @@ object IntervalUtils {
}
}
/**
 * Always throws an `IllegalArgumentException` reporting that an interval string does not
 * match any of the supported formats for the requested interval type.
 *
 * The message has the shape:
 * "Interval string does not match <intervalStr> format of `<fmt1>`, `<fmt2>`, ... when cast
 * to <typeName>: <input>[, <fallBackNotice>]".
 *
 * @param input          the string that was being parsed; echoed in the error message
 * @param startFiled     start field of the target interval type; together with `endField`
 *                       it is the key looked up in `supportedFormat`
 *                       (NOTE: the name is a misspelling of `startField`, kept as declared)
 * @param endField       end field of the target interval type
 * @param intervalStr    human-readable interval family used in the message; callers in this
 *                       file pass "year-month" or "day-time"
 * @param typeName       name of the target type shown after "when cast to"
 * @param fallBackNotice optional extra hint; when defined it is appended after ", "
 * @throws IllegalArgumentException always
 */
private def throwIllegalIntervalFormatException(
input: UTF8String,
startFiled: Byte,
endField: Byte,
intervalStr: String,
typeName: String,
fallBackNotice: Option[String] = None) = {
// `supportedFormat` is a Map, so a (start, end) pair that is not one of its keys would make
// the lookup itself fail before this exception is constructed.
throw new IllegalArgumentException(
s"Interval string does not match $intervalStr format of " +
s"${supportedFormat((startFiled, endField)).map(format => s"`$format`").mkString(", ")} " +
s"when cast to $typeName: ${input.toString}" +
s"${fallBackNotice.map(s => s", $s").getOrElse("")}")
}
/**
 * Verifies that the interval type implied by the format of an input string has the same
 * start and end fields as the target interval type, and throws otherwise.
 *
 * @param input             the string being parsed; used only for the error message
 * @param targetStartField  start field of the type the string is being cast to
 * @param targetEndField    end field of the type the string is being cast to
 * @param inputIntervalType the interval type that the matched string format corresponds to;
 *                          only day-time (`DT`) and year-month (`YM`) types are handled,
 *                          any other `DataType` is not matched below
 * @param fallBackNotice    optional hint forwarded to the error message
 * @throws IllegalArgumentException if the start or end fields differ
 */
private def checkIntervalStringDataType(
input: UTF8String,
targetStartField: Byte,
targetEndField: Byte,
inputIntervalType: DataType,
fallBackNotice: Option[String] = None): Unit = {
// Pick the interval family label and the *target* type's name for the message, and extract
// the fields carried by the type inferred from the input string.
val (intervalStr, typeName, inputStartField, inputEndField) = inputIntervalType match {
case DT(startField, endField) =>
("day-time", DT(targetStartField, targetEndField).typeName, startField, endField)
case YM(startField, endField) =>
("year-month", YM(targetStartField, targetEndField).typeName, startField, endField)
}
// The error lists the formats supported for the target fields, not for the input's fields.
if (targetStartField != inputStartField || targetEndField != inputEndField) {
throwIllegalIntervalFormatException(
input, targetStartField, targetEndField, intervalStr, typeName, fallBackNotice)
}
}
val supportedFormat = Map(
(YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"),
(YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"),
@ -140,56 +173,41 @@ object IntervalUtils {
startField: Byte,
endField: Byte): Int = {
def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
if (startField != targetStartField || endField != targetEndField) {
throw new IllegalArgumentException(s"Interval string does not match year-month format of " +
s"${supportedFormat((targetStartField, targetStartField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
}
}
def checkYMIntervalStringDataType(ym: YM): Unit =
checkIntervalStringDataType(input, startField, endField, ym)
input.trimAll().toString match {
case yearMonthRegex("-", year, month) =>
checkStringIntervalType(YM.YEAR, YM.MONTH)
toYMInterval(year, month, -1)
case yearMonthRegex(_, year, month) =>
checkStringIntervalType(YM.YEAR, YM.MONTH)
toYMInterval(year, month, 1)
case yearMonthRegex(sign, year, month) =>
checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
toYMInterval(year, month, finalSign(sign))
case yearMonthLiteralRegex(firstSign, secondSign, year, month) =>
checkStringIntervalType(YM.YEAR, YM.MONTH)
toYMInterval(year, month, getSign(firstSign, secondSign))
case yearMonthIndividualRegex(secondSign, value) =>
safeToInterval {
val sign = getSign("+", secondSign)
checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH))
toYMInterval(year, month, finalSign(firstSign, secondSign))
case yearMonthIndividualRegex(firstSign, value) =>
safeToInterval("year-month") {
val sign = finalSign(firstSign)
if (endField == YM.YEAR) {
sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
} else if (startField == YM.MONTH) {
Math.toIntExact(sign * value.toLong)
} else {
throw new IllegalArgumentException(
s"Interval string does not match year-month format of " +
s"${supportedFormat((YM.YEAR, YM.MONTH))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
throwIllegalIntervalFormatException(
input, startField, endField, "year-month", YM(startField, endField).typeName)
}
}
case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, suffix) =>
safeToInterval {
val sign = getSign(firstSign, secondSign)
safeToInterval("year-month") {
val sign = finalSign(firstSign, secondSign)
if ("YEAR".equalsIgnoreCase(suffix)) {
checkStringIntervalType(YM.YEAR, YM.YEAR)
checkYMIntervalStringDataType(YM(YM.YEAR, YM.YEAR))
sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR)
} else {
checkStringIntervalType(YM.MONTH, YM.MONTH)
checkYMIntervalStringDataType(YM(YM.MONTH, YM.MONTH))
Math.toIntExact(sign * value.toLong)
}
}
case _ => throw new IllegalArgumentException(
s"Interval string does not match year-month format of " +
s"${supportedFormat((YM.YEAR, YM.MONTH))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${YM(startField, endField).typeName}: ${input.toString}")
case _ => throwIllegalIntervalFormatException(input, startField, endField,
"year-month", YM(startField, endField).typeName)
}
}
@ -201,28 +219,26 @@ object IntervalUtils {
def fromYearMonthString(input: String): CalendarInterval = {
require(input != null, "Interval year-month string must be not null")
input.trim match {
case yearMonthRegex("-", yearStr, monthStr) =>
new CalendarInterval(toYMInterval(yearStr, monthStr, -1), 0, 0)
case yearMonthRegex(_, yearStr, monthStr) =>
new CalendarInterval(toYMInterval(yearStr, monthStr, 1), 0, 0)
case yearMonthRegex(sign, yearStr, monthStr) =>
new CalendarInterval(toYMInterval(yearStr, monthStr, finalSign(sign)), 0, 0)
case _ =>
throw new IllegalArgumentException(
s"Interval string does not match year-month format of 'y-m': $input")
}
}
private def safeToInterval[T](f: => T): T = {
private def safeToInterval[T](interval: String)(f: => T): T = {
try {
f
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(
s"Error parsing interval year-month string: ${e.getMessage}", e)
s"Error parsing interval $interval string: ${e.getMessage}", e)
}
}
private def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = {
safeToInterval {
safeToInterval("year-month") {
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
Math.toIntExact(totalMonths)
@ -279,15 +295,6 @@ object IntervalUtils {
startField: Byte,
endField: Byte): Long = {
def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = {
if (startField != targetStartField || endField != targetEndField) {
throw new IllegalArgumentException(s"Interval string does not match day-time format of " +
s"${supportedFormat((targetStartField, targetStartField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
}
}
def secondAndMicro(second: String, micro: String): String = {
if (micro != null) {
s"$second$micro"
@ -296,50 +303,53 @@ object IntervalUtils {
}
}
def checkDTIntervalStringDataType(dt: DT): Unit =
checkIntervalStringDataType(input, startField, endField, dt, Some(fallbackNotice))
input.trimAll().toString match {
case dayHourRegex(sign, day, hour) =>
checkStringIntervalType(DT.DAY, DT.HOUR)
toDTInterval(day, hour, "0", "0", getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
toDTInterval(day, hour, "0", "0", finalSign(sign))
case dayHourLiteralRegex(firstSign, secondSign, day, hour) =>
checkStringIntervalType(DT.DAY, DT.HOUR)
toDTInterval(day, hour, "0", "0", getSign(firstSign, secondSign))
checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR))
toDTInterval(day, hour, "0", "0", finalSign(firstSign, secondSign))
case dayMinuteRegex(sign, day, hour, minute) =>
checkStringIntervalType(DT.DAY, DT.MINUTE)
toDTInterval(day, hour, minute, "0", getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
toDTInterval(day, hour, minute, "0", finalSign(sign))
case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) =>
checkStringIntervalType(DT.DAY, DT.MINUTE)
toDTInterval(day, hour, minute, "0", getSign(firstSign, secondSign))
checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE))
toDTInterval(day, hour, minute, "0", finalSign(firstSign, secondSign))
case daySecondRegex(sign, day, hour, minute, second, micro) =>
checkStringIntervalType(DT.DAY, DT.SECOND)
toDTInterval(day, hour, minute, secondAndMicro(second, micro), getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(sign))
case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) =>
checkStringIntervalType(DT.DAY, DT.SECOND)
checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND))
toDTInterval(day, hour, minute, secondAndMicro(second, micro),
getSign(firstSign, secondSign))
finalSign(firstSign, secondSign))
case hourMinuteRegex(sign, hour, minute) =>
checkStringIntervalType(DT.HOUR, DT.MINUTE)
toDTInterval(hour, minute, "0", getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
toDTInterval(hour, minute, "0", finalSign(sign))
case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) =>
checkStringIntervalType(DT.HOUR, DT.MINUTE)
toDTInterval(hour, minute, "0", getSign(firstSign, secondSign))
checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE))
toDTInterval(hour, minute, "0", finalSign(firstSign, secondSign))
case hourSecondRegex(sign, hour, minute, second, micro) =>
checkStringIntervalType(DT.HOUR, DT.SECOND)
toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(sign))
case hourSecondLiteralRegex(firstSign, secondSign, hour, minute, second, micro) =>
checkStringIntervalType(DT.HOUR, DT.SECOND)
toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND))
toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
case minuteSecondRegex(sign, minute, second, micro) =>
checkStringIntervalType(DT.MINUTE, DT.SECOND)
toDTInterval(minute, secondAndMicro(second, micro), getSign(null, sign))
checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
toDTInterval(minute, secondAndMicro(second, micro), finalSign(sign))
case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) =>
checkStringIntervalType(DT.MINUTE, DT.SECOND)
toDTInterval(minute, secondAndMicro(second, micro), getSign(firstSign, secondSign))
checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND))
toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign))
case dayTimeIndividualRegex(secondSign, value, suffix) =>
safeToInterval {
val sign = getSign("+", secondSign)
case dayTimeIndividualRegex(firstSign, value, suffix) =>
safeToInterval("day-time") {
val sign = finalSign(firstSign)
(startField, endField) match {
case (DT.DAY, DT.DAY) if suffix == null && value.length <= 9 =>
sign * value.toLong * MICROS_PER_DAY
@ -352,46 +362,35 @@ object IntervalUtils {
case 1 => parseSecondNano(secondAndMicro(value, suffix))
case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
}
case (_, _) => throw new IllegalArgumentException(
s"Interval string does not match day-time format of " +
s"${supportedFormat((startField, endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
case (_, _) => throwIllegalIntervalFormatException(input, startField, endField,
"day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}
case dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) =>
safeToInterval {
val sign = getSign(firstSign, secondSign)
safeToInterval("day-time") {
val sign = finalSign(firstSign, secondSign)
unit match {
case "DAY" if suffix == null && value.length <= 9 =>
checkStringIntervalType(DT.DAY, DT.DAY)
checkDTIntervalStringDataType(DT(DT.DAY, DT.DAY))
sign * value.toLong * MICROS_PER_DAY
case "HOUR" if suffix == null && value.length <= 10 =>
checkStringIntervalType(DT.HOUR, DT.HOUR)
checkDTIntervalStringDataType(DT(DT.HOUR, DT.HOUR))
sign * value.toLong * MICROS_PER_HOUR
case "MINUTE" if suffix == null && value.length <= 12 =>
checkStringIntervalType(DT.MINUTE, DT.MINUTE)
checkDTIntervalStringDataType(DT(DT.MINUTE, DT.MINUTE))
sign * value.toLong * MICROS_PER_MINUTE
case "SECOND" if value.length <= 13 =>
checkStringIntervalType(DT.SECOND, DT.SECOND)
checkDTIntervalStringDataType(DT(DT.SECOND, DT.SECOND))
sign match {
case 1 => parseSecondNano(secondAndMicro(value, suffix))
case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}")
}
case _ => throw new IllegalArgumentException(
s"Interval string does not match day-time format of " +
s"${supportedFormat((startField, endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${DT(startField, endField).typeName}: ${input.toString}")
case _ => throwIllegalIntervalFormatException(input, startField, endField,
"day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}
case _ =>
throw new IllegalArgumentException(
s"Interval string does not match day-time format of " +
s"${supportedFormat((startField, endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${DT(startField, endField).typeName}: ${input.toString}, " +
s"$fallbackNotice")
case _ => throwIllegalIntervalFormatException(input, startField, endField,
"day-time", DT(startField, endField).typeName, Some(fallbackNotice))
}
}

View file

@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@ -1113,10 +1114,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
if (!isTryCast) {
Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval =>
val dataType = YearMonthIntervalType()
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), YearMonthIntervalType()).eval()
cast(Literal.create(interval), dataType).eval()
}.getMessage
assert(e.contains("Interval string does not match year-month format"))
assert(e.contains(s"Interval string does not match year-month format of " +
s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${dataType.typeName}: $interval"))
}
Seq(("1", YearMonthIntervalType(YEAR, MONTH)),
("1", YearMonthIntervalType(YEAR, MONTH)),
@ -1132,7 +1137,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
assert(e.contains("Interval string does not match year-month format"))
assert(e.contains(s"Interval string does not match year-month format of " +
s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${dataType.typeName}: $interval"))
}
}
}
@ -1249,7 +1257,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
assert(e.contains("Interval string does not match day-time format"))
assert(e.contains(s"Interval string does not match day-time format of " +
s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${dataType.typeName}: $interval, " +
s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
"to restore the behavior before Spark 3.0."))
}
// Check first field out of bound
@ -1267,7 +1280,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), dataType).eval()
}.getMessage
assert(e.contains("Interval string does not match day-time format"))
assert(e.contains(s"Interval string does not match day-time format of " +
s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField))
.map(format => s"`$format`").mkString(", ")} " +
s"when cast to ${dataType.typeName}: $interval, " +
s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " +
"to restore the behavior before Spark 3.0."))
}
}
}