[SPARK-36524][SQL] Common class for ANSI interval types

### What changes were proposed in this pull request?
Add a new abstract class `AnsiIntervalType` to `AbstractDataType.scala`, and make both `YearMonthIntervalType` and `DayTimeIntervalType` extend it.

### Why are the changes needed?
To improve code maintainability. The change makes it possible to replace checks for both `YearMonthIntervalType` and `DayTimeIntervalType` with a single check for `AnsiIntervalType`, for instance:
```scala
    case _: YearMonthIntervalType | _: DayTimeIntervalType => false
```
with
```scala
    case _: AnsiIntervalType => false
```
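
For illustration, here is a minimal, self-contained sketch of the new hierarchy (the stub classes below omit the `startField`/`endField` parameters and all members of the real Spark types, and `AnsiIntervalTypeDemo` is a hypothetical name used only here):
```scala
// Minimal, simplified model of the hierarchy this PR introduces; it only
// demonstrates that a single AnsiIntervalType check covers both subclasses.
abstract class AtomicType
abstract class AnsiIntervalType extends AtomicType
case class YearMonthIntervalType() extends AnsiIntervalType
case class DayTimeIntervalType() extends AnsiIntervalType

object AnsiIntervalTypeDemo extends App {
  // The same shape as the supportsDataType checks updated in this PR.
  def supportsDataType(dataType: AtomicType): Boolean = dataType match {
    case _: AnsiIntervalType => false // one arm now covers both interval types
    case _ => true
  }

  assert(!supportsDataType(YearMonthIntervalType()))
  assert(!supportsDataType(DayTimeIntervalType()))
}
```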

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By existing test suites.

Closes #33753 from MaxGekk/ansi-interval-type-trait.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
Commit: 82a31508af (parent: ea13c5a743)
Date: 2021-08-17 12:27:56 +03:00
22 changed files with 33 additions and 29 deletions

```diff
@@ -71,7 +71,7 @@ private[sql] object AvroUtils extends Logging {
   }

   def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -377,9 +377,9 @@ class Analyzer(override val catalogManager: CatalogManager)
         TimestampAddYMInterval(r, l)
       case (CalendarIntervalType, CalendarIntervalType) |
            (_: DayTimeIntervalType, _: DayTimeIntervalType) => a
-      case (_: NullType, _: DayTimeIntervalType | _: YearMonthIntervalType) =>
+      case (_: NullType, _: AnsiIntervalType) =>
         a.copy(left = Cast(a.left, a.right.dataType))
-      case (_: DayTimeIntervalType | _: YearMonthIntervalType, _: NullType) =>
+      case (_: AnsiIntervalType, _: NullType) =>
         a.copy(right = Cast(a.right, a.left.dataType))
       case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
       case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
@@ -400,9 +400,9 @@ class Analyzer(override val catalogManager: CatalogManager)
         DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
       case (CalendarIntervalType, CalendarIntervalType) |
            (_: DayTimeIntervalType, _: DayTimeIntervalType) => s
-      case (_: NullType, _: DayTimeIntervalType | _: YearMonthIntervalType) =>
+      case (_: NullType, _: AnsiIntervalType) =>
         s.copy(left = Cast(s.left, s.right.dataType))
-      case (_: DayTimeIntervalType | _: YearMonthIntervalType, _: NullType) =>
+      case (_: AnsiIntervalType, _: NullType) =>
         s.copy(right = Cast(s.right, s.left.dataType))
       case (DateType, CalendarIntervalType) =>
         DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
```

```diff
@@ -275,7 +275,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     // If a binary operation contains interval type and string literal, we can't decide which
     // interval type the string literal should be promoted as. There are many possible interval
     // types, such as year interval, month interval, day interval, hour interval, etc.
-    case _: YearMonthIntervalType | _: DayTimeIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
     case _ => false
   }
```

```diff
@@ -981,7 +981,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case u: UserDefinedType[_] =>
         alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
           s"update a UserDefinedType[${u.sql}] by updating its fields")
-      case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType =>
+      case _: CalendarIntervalType | _: AnsiIntervalType =>
         alter.failAnalysis(s"Cannot update ${table.name} field $fieldName to interval type")
       case _ => // update is okay
     }
```

```diff
@@ -85,7 +85,7 @@ case class UnaryMinus(
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       val method = if (failOnError) "negateExact" else "negate"
       defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
-    case _: DayTimeIntervalType | _: YearMonthIntervalType =>
+    case _: AnsiIntervalType =>
       nullSafeCodeGen(ctx, ev, eval => {
         val mathClass = classOf[Math].getName
         s"${ev.value} = $mathClass.negateExact($eval);"
@@ -229,7 +229,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
     case CalendarIntervalType =>
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
-    case _: DayTimeIntervalType | _: YearMonthIntervalType =>
+    case _: AnsiIntervalType =>
       assert(exactMathMethod.isDefined,
         s"The expression '$nodeName' must override the exactMathMethod() method " +
           "if it is supposed to operate over interval types.")
```

```diff
@@ -2597,7 +2597,7 @@ case class Sequence(
   }

   private def isNotIntervalType(expr: Expression) = expr.dataType match {
-    case CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType => false
+    case CalendarIntervalType | _: AnsiIntervalType => false
     case _ => true
   }
```

```diff
@@ -2733,7 +2733,7 @@ object DatePart {
       throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source)
     source.dataType match {
-      case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType =>
+      case _: AnsiIntervalType | CalendarIntervalType =>
        ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException)
      case _ =>
        DatePart.parseExtractField(fieldStr, source, analysisException)
```

```diff
@@ -63,7 +63,7 @@ object TypeUtils {
   def checkForAnsiIntervalOrNumericType(
       dt: DataType, funcName: String): TypeCheckResult = dt match {
-    case _: YearMonthIntervalType | _: DayTimeIntervalType | NullType =>
+    case _: AnsiIntervalType | NullType =>
       TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
     case other => TypeCheckResult.TypeCheckFailure(
@@ -117,7 +117,7 @@ object TypeUtils {
   def invokeOnceForInterval(dataType: DataType)(f: => Unit): Unit = {
     def isInterval(dataType: DataType): Boolean = dataType match {
-      case CalendarIntervalType | _: DayTimeIntervalType | _: YearMonthIntervalType => true
+      case CalendarIntervalType | _: AnsiIntervalType => true
       case _ => false
     }
     if (dataType.existsRecursively(isInterval)) f
```

```diff
@@ -222,3 +222,8 @@ private[sql] object AnyTimestampType extends AbstractDataType with Serializable
   def unapply(e: Expression): Boolean = acceptsType(e.dataType)
 }
+
+/**
+ * The interval type which conforms to the ANSI SQL standard.
+ */
+private[sql] abstract class AnsiIntervalType extends AtomicType
```

```diff
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString
  * @since 3.2.0
  */
 @Unstable
-case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AtomicType {
+case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
   /**
    * Internally, values of day-time intervals are stored in `Long` values as amount of time in terms
    * of microseconds that are calculated by the formula:
```

```diff
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString
  * @since 3.2.0
  */
 @Unstable
-case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AtomicType {
+case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
   /**
    * Internally, values of year-month intervals are stored in `Int` values as amount of months
    * that are calculated by the formula:
```

```diff
@@ -148,7 +148,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]

   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -134,7 +134,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]

   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -232,7 +232,7 @@ class OrcFileFormat
   }

   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -373,7 +373,7 @@ class ParquetFileFormat
   }

   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuild
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.types.{AtomicType, DataType, DayTimeIntervalType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{AnsiIntervalType, AtomicType, DataType, StructType, UserDefinedType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

 case class CSVTable(
@@ -55,7 +55,7 @@ case class CSVTable(
   }

   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -55,7 +55,7 @@ case class JsonTable(
   }

   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -49,7 +49,7 @@ case class OrcTable(
   }

   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -49,7 +49,7 @@ case class ParquetTable(
   }

   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```

```diff
@@ -121,7 +121,7 @@ private[hive] class SparkExecuteStatementOperation(
           false,
           timeFormatters)
       case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] |
-          _: YearMonthIntervalType | _: DayTimeIntervalType | _: TimestampNTZType =>
+          _: AnsiIntervalType | _: TimestampNTZType =>
         to += toHiveString((from.get(ordinal), dataTypes(ordinal)), false, timeFormatters)
     }
   }
```

```diff
@@ -131,8 +131,7 @@ private[hive] class SparkGetColumnsOperation(
    */
   private def getColumnSize(typ: DataType): Option[Int] = typ match {
     case dt @ (BooleanType | _: NumericType | DateType | TimestampType | TimestampNTZType |
-      CalendarIntervalType | NullType |
-      _: YearMonthIntervalType | _: DayTimeIntervalType) =>
+      CalendarIntervalType | NullType | _: AnsiIntervalType) =>
       Some(dt.defaultSize)
     case CharType(n) => Some(n)
     case StructType(fields) =>
@@ -187,7 +186,7 @@ private[hive] class SparkGetColumnsOperation(
     case _: MapType => java.sql.Types.JAVA_OBJECT
     case _: StructType => java.sql.Types.STRUCT
     // Hive's year-month and day-time intervals are mapping to java.sql.Types.OTHER
-    case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType =>
+    case _: CalendarIntervalType | _: AnsiIntervalType =>
       java.sql.Types.OTHER
     case _ => throw new IllegalArgumentException(s"Unrecognized type name: ${typ.sql}")
   }
```

```diff
@@ -194,7 +194,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
   }

   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false

     case _: AtomicType => true
```