[SPARK-36524][SQL] Common class for ANSI interval types

### What changes were proposed in this pull request?
Add a new type, `AnsiIntervalType`, to `AbstractDataType.scala`, and make `YearMonthIntervalType` and `DayTimeIntervalType` extend it.
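
In outline, the resulting hierarchy looks as follows (a simplified sketch; the real case classes carry their full bodies). Both concrete types already extended `AtomicType`, so inserting `AnsiIntervalType` between them and `AtomicType` preserves existing behavior:
```scala
// New common parent of the ANSI interval types (AbstractDataType.scala):
private[sql] abstract class AnsiIntervalType extends AtomicType

// The concrete ANSI interval types now extend it instead of AtomicType directly:
case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType
case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType
```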

### Why are the changes needed?
To improve code maintenance. The change allows replacing checks against both `YearMonthIntervalType` and `DayTimeIntervalType` with a single check against `AnsiIntervalType`. For instance:
```scala
    case _: YearMonthIntervalType | _: DayTimeIntervalType => false
```
with
```scala
    case _: AnsiIntervalType => false
```
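
As a minimal self-contained illustration, one case arm on the common parent now covers both concrete types. The helper below is hypothetical (not part of this PR) and, since `AnsiIntervalType` is `private[sql]`, would only compile inside Spark's `org.apache.spark.sql` package tree:
```scala
import org.apache.spark.sql.types._

// Hypothetical helper, for illustration only: matching the common parent
// covers both YearMonthIntervalType and DayTimeIntervalType.
def isAnsiInterval(dt: DataType): Boolean = dt match {
  case _: AnsiIntervalType => true
  case _ => false
}

// isAnsiInterval(YearMonthIntervalType()) == true
// isAnsiInterval(DayTimeIntervalType())   == true
// isAnsiInterval(CalendarIntervalType)    == false
```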

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By existing test suites.

Closes #33753 from MaxGekk/ansi-interval-type-trait.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
Commit 82a31508af (parent ea13c5a743), 2021-08-17 12:27:56 +03:00.
22 changed files with 33 additions and 29 deletions.

```diff
@@ -71,7 +71,7 @@ private[sql] object AvroUtils extends Logging {
   }
   def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -377,9 +377,9 @@ class Analyzer(override val catalogManager: CatalogManager)
         TimestampAddYMInterval(r, l)
       case (CalendarIntervalType, CalendarIntervalType) |
            (_: DayTimeIntervalType, _: DayTimeIntervalType) => a
-      case (_: NullType, _: DayTimeIntervalType | _: YearMonthIntervalType) =>
+      case (_: NullType, _: AnsiIntervalType) =>
         a.copy(left = Cast(a.left, a.right.dataType))
-      case (_: DayTimeIntervalType | _: YearMonthIntervalType, _: NullType) =>
+      case (_: AnsiIntervalType, _: NullType) =>
         a.copy(right = Cast(a.right, a.left.dataType))
       case (DateType, CalendarIntervalType) => DateAddInterval(l, r, ansiEnabled = f)
       case (_, CalendarIntervalType | _: DayTimeIntervalType) => Cast(TimeAdd(l, r), l.dataType)
@@ -400,9 +400,9 @@ class Analyzer(override val catalogManager: CatalogManager)
         DatetimeSub(l, r, TimestampAddYMInterval(l, UnaryMinus(r, f)))
       case (CalendarIntervalType, CalendarIntervalType) |
            (_: DayTimeIntervalType, _: DayTimeIntervalType) => s
-      case (_: NullType, _: DayTimeIntervalType | _: YearMonthIntervalType) =>
+      case (_: NullType, _: AnsiIntervalType) =>
         s.copy(left = Cast(s.left, s.right.dataType))
-      case (_: DayTimeIntervalType | _: YearMonthIntervalType, _: NullType) =>
+      case (_: AnsiIntervalType, _: NullType) =>
         s.copy(right = Cast(s.right, s.left.dataType))
       case (DateType, CalendarIntervalType) =>
         DatetimeSub(l, r, DateAddInterval(l, UnaryMinus(r, f), ansiEnabled = f))
```

```diff
@@ -275,7 +275,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     // If a binary operation contains interval type and string literal, we can't decide which
     // interval type the string literal should be promoted as. There are many possible interval
     // types, such as year interval, month interval, day interval, hour interval, etc.
-    case _: YearMonthIntervalType | _: DayTimeIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
     case _ => false
   }
```

```diff
@@ -981,7 +981,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
       case u: UserDefinedType[_] =>
         alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
           s"update a UserDefinedType[${u.sql}] by updating its fields")
-      case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType =>
+      case _: CalendarIntervalType | _: AnsiIntervalType =>
         alter.failAnalysis(s"Cannot update ${table.name} field $fieldName to interval type")
       case _ => // update is okay
     }
```

```diff
@@ -85,7 +85,7 @@ case class UnaryMinus(
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       val method = if (failOnError) "negateExact" else "negate"
       defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
-    case _: DayTimeIntervalType | _: YearMonthIntervalType =>
+    case _: AnsiIntervalType =>
       nullSafeCodeGen(ctx, ev, eval => {
         val mathClass = classOf[Math].getName
         s"${ev.value} = $mathClass.negateExact($eval);"
@@ -229,7 +229,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
     case CalendarIntervalType =>
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
-    case _: DayTimeIntervalType | _: YearMonthIntervalType =>
+    case _: AnsiIntervalType =>
       assert(exactMathMethod.isDefined,
         s"The expression '$nodeName' must override the exactMathMethod() method " +
           "if it is supposed to operate over interval types.")
```

```diff
@@ -2597,7 +2597,7 @@ case class Sequence(
   }
   private def isNotIntervalType(expr: Expression) = expr.dataType match {
-    case CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType => false
+    case CalendarIntervalType | _: AnsiIntervalType => false
     case _ => true
   }
```

```diff
@@ -2733,7 +2733,7 @@ object DatePart {
       throw QueryCompilationErrors.literalTypeUnsupportedForSourceTypeError(fieldStr, source)
     source.dataType match {
-      case _: YearMonthIntervalType | _: DayTimeIntervalType | CalendarIntervalType =>
+      case _: AnsiIntervalType | CalendarIntervalType =>
         ExtractIntervalPart.parseExtractField(fieldStr, source, analysisException)
       case _ =>
         DatePart.parseExtractField(fieldStr, source, analysisException)
```

```diff
@@ -63,7 +63,7 @@ object TypeUtils {
   def checkForAnsiIntervalOrNumericType(
       dt: DataType, funcName: String): TypeCheckResult = dt match {
-    case _: YearMonthIntervalType | _: DayTimeIntervalType | NullType =>
+    case _: AnsiIntervalType | NullType =>
       TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
     case other => TypeCheckResult.TypeCheckFailure(
@@ -117,7 +117,7 @@ object TypeUtils {
   def invokeOnceForInterval(dataType: DataType)(f: => Unit): Unit = {
     def isInterval(dataType: DataType): Boolean = dataType match {
-      case CalendarIntervalType | _: DayTimeIntervalType | _: YearMonthIntervalType => true
+      case CalendarIntervalType | _: AnsiIntervalType => true
       case _ => false
     }
     if (dataType.existsRecursively(isInterval)) f
```

```diff
@@ -222,3 +222,8 @@ private[sql] object AnyTimestampType extends AbstractDataType with Serializable
   def unapply(e: Expression): Boolean = acceptsType(e.dataType)
 }
+
+/**
+ * The interval type which conforms to the ANSI SQL standard.
+ */
+private[sql] abstract class AnsiIntervalType extends AtomicType
```

```diff
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.fieldToString
  * @since 3.2.0
  */
 @Unstable
-case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AtomicType {
+case class DayTimeIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
   /**
    * Internally, values of day-time intervals are stored in `Long` values as amount of time in terms
    * of microseconds that are calculated by the formula:
```

```diff
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types.YearMonthIntervalType.fieldToString
  * @since 3.2.0
  */
 @Unstable
-case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AtomicType {
+case class YearMonthIntervalType(startField: Byte, endField: Byte) extends AnsiIntervalType {
   /**
    * Internally, values of year-month intervals are stored in `Int` values as amount of months
    * that are calculated by the formula:
```

```diff
@@ -148,7 +148,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -134,7 +134,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
   override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -232,7 +232,7 @@ class OrcFileFormat
   }
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -373,7 +373,7 @@ class ParquetFileFormat
   }
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuild
 import org.apache.spark.sql.execution.datasources.FileFormat
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
 import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.types.{AtomicType, DataType, DayTimeIntervalType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{AnsiIntervalType, AtomicType, DataType, StructType, UserDefinedType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 case class CSVTable(
@@ -55,7 +55,7 @@ case class CSVTable(
   }
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -55,7 +55,7 @@ case class JsonTable(
   }
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -49,7 +49,7 @@ case class OrcTable(
   }
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -49,7 +49,7 @@ case class ParquetTable(
   }
   override def supportsDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```

```diff
@@ -121,7 +121,7 @@ private[hive] class SparkExecuteStatementOperation(
           false,
           timeFormatters)
       case _: ArrayType | _: StructType | _: MapType | _: UserDefinedType[_] |
-          _: YearMonthIntervalType | _: DayTimeIntervalType | _: TimestampNTZType =>
+          _: AnsiIntervalType | _: TimestampNTZType =>
         to += toHiveString((from.get(ordinal), dataTypes(ordinal)), false, timeFormatters)
     }
   }
```

```diff
@@ -131,8 +131,7 @@ private[hive] class SparkGetColumnsOperation(
    */
   private def getColumnSize(typ: DataType): Option[Int] = typ match {
     case dt @ (BooleanType | _: NumericType | DateType | TimestampType | TimestampNTZType |
-        CalendarIntervalType | NullType |
-        _: YearMonthIntervalType | _: DayTimeIntervalType) =>
+        CalendarIntervalType | NullType | _: AnsiIntervalType) =>
       Some(dt.defaultSize)
     case CharType(n) => Some(n)
     case StructType(fields) =>
@@ -187,7 +186,7 @@ private[hive] class SparkGetColumnsOperation(
     case _: MapType => java.sql.Types.JAVA_OBJECT
     case _: StructType => java.sql.Types.STRUCT
     // Hive's year-month and day-time intervals are mapping to java.sql.Types.OTHER
-    case _: CalendarIntervalType | _: YearMonthIntervalType | _: DayTimeIntervalType =>
+    case _: CalendarIntervalType | _: AnsiIntervalType =>
       java.sql.Types.OTHER
     case _ => throw new IllegalArgumentException(s"Unrecognized type name: ${typ.sql}")
   }
```

```diff
@@ -194,7 +194,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
   }
   override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: DayTimeIntervalType | _: YearMonthIntervalType => false
+    case _: AnsiIntervalType => false
     case _: AtomicType => true
```