[SPARK-30551][SQL] Disable comparison for interval type

### What changes were proposed in this pull request?

Since we are not going to follow the ANSI SQL standard and implement separate year-month and day-time interval types, it does not make sense for the current interval type to compare its year-month part against its day-time part.

Additionally, the current ordering logic comes from PostgreSQL, where the interval implementation is messy, and we are not aiming for PostgreSQL compliance anyway.

This PR reverts https://github.com/apache/spark/pull/26681 and https://github.com/apache/spark/pull/26337.
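
For illustration, here is a rough Scala sketch of the ordering semantics being removed. The `Interval` case class and the constant values are stand-ins (the real code, shown in the diff below, used `DateTimeConstants.MICROS_PER_DAY` and `DAYS_PER_MONTH`); the point is that the removed `CalendarInterval.compareTo` collapsed months and microseconds into an approximate day count, which is exactly the year-month-versus-day-time mixing described above.

```scala
// Rough sketch of the removed CalendarInterval.compareTo logic; Interval,
// MicrosPerDay and DaysPerMonth are illustrative stand-ins, not Spark's API.
final case class Interval(months: Int, days: Int, microseconds: Long)

object LegacyIntervalOrdering {
  private val MicrosPerDay = 24L * 60 * 60 * 1000 * 1000
  private val DaysPerMonth = 30L // assumed value of DateTimeConstants.DAYS_PER_MONTH

  def compare(a: Interval, b: Interval): Int = {
    // Normalize months, days and whole days of microseconds into one day count.
    val aDays = a.microseconds / MicrosPerDay + a.days + a.months * DaysPerMonth
    val bDays = b.microseconds / MicrosPerDay + b.days + b.months * DaysPerMonth
    if (aDays != bDays) java.lang.Long.compare(aDays, bDays)
    else java.lang.Long.compare(a.microseconds % MicrosPerDay, b.microseconds % MicrosPerDay)
  }
}
```

Under this scheme an interval of 1 month compares equal to an interval of 30 days (with the assumed 30-day month), even though a month is not a fixed number of days; queries like `select interval '1 month' = interval '30 days'` in the removed SQL tests below exercised exactly this ambiguity.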

### Why are the changes needed?

It makes the interval type more future-proof.
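
As a hedged sketch of the effect (import paths are assumed from the diff context; exact behavior depends on the Spark build): with the `CalendarIntervalType` case removed from `RowOrdering.isOrderable`, interval columns are no longer treated as orderable, so sorting and binary comparisons on them are rejected at analysis time instead of silently using the PostgreSQL-style day-based ordering.

```scala
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types.CalendarIntervalType

object IntervalOrderableCheck {
  def main(args: Array[String]): Unit = {
    // Before this PR: true (intervals could be sorted, compared, min/max'ed).
    // After this PR: false, so ORDER BY and comparison operators on interval
    // columns fail analysis rather than relying on the removed ordering.
    println(RowOrdering.isOrderable(CalendarIntervalType))
  }
}
```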

### Does this PR introduce any user-facing change?

No. Interval comparisons are new in 3.0, so removing them does not change any released behavior.

### How was this patch tested?

Existing unit tests should continue to pass.

Closes #27262 from yaooqinn/SPARK-30551.

Authored-by: Kent Yao <yaooqinn@hotmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit 17857f9b8b (parent 0d99d7e3f2), committed 2020-01-19 15:27:51 +08:00
16 changed files with 636 additions and 1112 deletions


@ -29,7 +29,7 @@ import static org.apache.spark.sql.catalyst.util.DateTimeConstants.*;
/**
* The internal representation of interval type.
*/
public final class CalendarInterval implements Serializable, Comparable<CalendarInterval> {
public final class CalendarInterval implements Serializable {
public final int months;
public final int days;
public final long microseconds;
@ -59,29 +59,6 @@ public final class CalendarInterval implements Serializable, Comparable<Calendar
return Objects.hash(months, days, microseconds);
}
@Override
public int compareTo(CalendarInterval that) {
long thisAdjustDays =
this.microseconds / MICROS_PER_DAY + this.days + this.months * DAYS_PER_MONTH;
long thatAdjustDays =
that.microseconds / MICROS_PER_DAY + that.days + that.months * DAYS_PER_MONTH;
long daysDiff = thisAdjustDays - thatAdjustDays;
if (daysDiff == 0) {
long msDiff = (this.microseconds % MICROS_PER_DAY) - (that.microseconds % MICROS_PER_DAY);
if (msDiff == 0) {
return 0;
} else if (msDiff > 0) {
return 1;
} else {
return -1;
}
} else if (daysDiff > 0){
return 1;
} else {
return -1;
}
}
@Override
public String toString() {
if (months == 0 && days == 0 && microseconds == 0) {
@ -133,16 +110,4 @@ public final class CalendarInterval implements Serializable, Comparable<Calendar
* @throws ArithmeticException if a numeric overflow occurs
*/
public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); }
/**
* A constant holding the minimum value an {@code CalendarInterval} can have.
*/
public static CalendarInterval MIN_VALUE =
new CalendarInterval(Integer.MIN_VALUE, Integer.MIN_VALUE, Long.MIN_VALUE);
/**
* A constant holding the maximum value an {@code CalendarInterval} can have.
*/
public static CalendarInterval MAX_VALUE =
new CalendarInterval(Integer.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE);
}


@ -605,7 +605,6 @@ class CodegenContext extends Logging {
s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)"
case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case CalendarIntervalType => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
@ -630,7 +629,6 @@ class CodegenContext extends Logging {
// use c1 - c2 may overflow
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case CalendarIntervalType => s"$c1.compareTo($c2)"
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType


@ -101,7 +101,6 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder],
def isOrderable(dataType: DataType): Boolean = dataType match {
case NullType => true
case dt: AtomicType => true
case CalendarIntervalType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
case array: ArrayType => isOrderable(array.elementType)
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)


@ -71,7 +71,6 @@ object TypeUtils {
def getInterpretedOrdering(t: DataType): Ordering[Any] = {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case c: CalendarIntervalType => c.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)


@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
private[sql] object TypeCollection {
/**
* Types that include numeric types and interval type, which support numeric type calculations,
* i.e. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
*/
val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)


@ -18,7 +18,6 @@
package org.apache.spark.sql.types
import org.apache.spark.annotation.Stable
import org.apache.spark.unsafe.types.CalendarInterval
/**
* The data type representing calendar intervals. The calendar interval is stored internally in
@ -40,8 +39,6 @@ class CalendarIntervalType private() extends DataType {
override def simpleString: String = "interval"
val ordering: Ordering[CalendarInterval] = Ordering[CalendarInterval]
private[spark] override def asNullable: CalendarIntervalType = this
}


@ -426,38 +426,33 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}
test("interval overflow check") {
intercept[ArithmeticException](negateExact(new CalendarInterval(Int.MinValue, 0, 0)))
assert(negate(new CalendarInterval(Int.MinValue, 0, 0)) ===
new CalendarInterval(Int.MinValue, 0, 0))
intercept[ArithmeticException](negateExact(CalendarInterval.MIN_VALUE))
assert(negate(CalendarInterval.MIN_VALUE) === CalendarInterval.MIN_VALUE)
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, 1)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 1, 0)))
intercept[ArithmeticException](addExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(1, 0, 0)))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, 1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(add(CalendarInterval.MAX_VALUE, new CalendarInterval(1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))
val maxMonth = new CalendarInterval(Int.MaxValue, 0, 0)
val minMonth = new CalendarInterval(Int.MinValue, 0, 0)
val oneMonth = new CalendarInterval(1, 0, 0)
val maxDay = new CalendarInterval(0, Int.MaxValue, 0)
val minDay = new CalendarInterval(0, Int.MinValue, 0)
val oneDay = new CalendarInterval(0, 1, 0)
val maxMicros = new CalendarInterval(0, 0, Long.MaxValue)
val minMicros = new CalendarInterval(0, 0, Long.MinValue)
val oneMicros = new CalendarInterval(0, 0, 1)
intercept[ArithmeticException](negateExact(minMonth))
assert(negate(minMonth) === minMonth)
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, 0, -1)))
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(0, -1, 0)))
intercept[ArithmeticException](subtractExact(CalendarInterval.MAX_VALUE,
new CalendarInterval(-1, 0, 0)))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, 0, -1)) ===
new CalendarInterval(Int.MaxValue, Int.MaxValue, Long.MinValue))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(0, -1, 0)) ===
new CalendarInterval(Int.MaxValue, Int.MinValue, Long.MaxValue))
assert(subtract(CalendarInterval.MAX_VALUE, new CalendarInterval(-1, 0, 0)) ===
new CalendarInterval(Int.MinValue, Int.MaxValue, Long.MaxValue))
intercept[ArithmeticException](addExact(maxMonth, oneMonth))
intercept[ArithmeticException](addExact(maxDay, oneDay))
intercept[ArithmeticException](addExact(maxMicros, oneMicros))
assert(add(maxMonth, oneMonth) === minMonth)
assert(add(maxDay, oneDay) === minDay)
assert(add(maxMicros, oneMicros) === minMicros)
intercept[ArithmeticException](multiplyExact(CalendarInterval.MAX_VALUE, 2))
intercept[ArithmeticException](divideExact(CalendarInterval.MAX_VALUE, 0.5))
intercept[ArithmeticException](subtractExact(minDay, oneDay))
intercept[ArithmeticException](subtractExact(minMonth, oneMonth))
intercept[ArithmeticException](subtractExact(minMicros, oneMicros))
assert(subtract(minMonth, oneMonth) === maxMonth)
assert(subtract(minDay, oneDay) === maxDay)
assert(subtract(minMicros, oneMicros) === maxMicros)
intercept[ArithmeticException](multiplyExact(maxMonth, 2))
intercept[ArithmeticException](divideExact(maxDay, 0.5))
}
}


@ -268,9 +268,9 @@ class Dataset[T] private[sql](
}
}
private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
queryExecution.analyzed.output.filter { attr =>
TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
private[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
}
}


@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructType, TypeCollection}
import org.apache.spark.sql.types.{NumericType, StructType}
/**
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
case expr: Expression => Alias(expr, toPrettySQL(expr))()
}
private[this] def aggregateNumericOrIntervalColumns(
colNames: String*)(f: Expression => AggregateFunction): DataFrame = {
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {
val columnExprs = if (colNames.isEmpty) {
// No columns specified. Use all numeric calculation supported columns.
df.numericCalculationSupportedColumns
// No columns specified. Use all numeric columns.
df.numericColumns
} else {
// Make sure all specified columns are numeric calculation supported columns.
// Make sure all specified columns are numeric.
colNames.map { colName =>
val namedExpr = df.resolve(colName)
if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
if (!namedExpr.dataType.isInstanceOf[NumericType]) {
throw new AnalysisException(
s""""$colName" is not a numeric or calendar interval column. """ +
"Aggregation function can only be applied on a numeric or calendar interval column.")
s""""$colName" is not a numeric column. """ +
"Aggregation function can only be applied on a numeric column.")
}
namedExpr
}
@ -269,8 +269,7 @@ class RelationalGroupedDataset protected[sql](
def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))
/**
* Compute the average value for each numeric or calender interval columns for each group. This
* is an alias for `avg`.
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the average values for them.
*
@ -278,11 +277,11 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def mean(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average)
}
/**
* Compute the max value for each numeric calender interval columns for each group.
* Compute the max value for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the max values for them.
*
@ -290,11 +289,11 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def max(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Max)
aggregateNumericColumns(colNames : _*)(Max)
}
/**
* Compute the mean value for each numeric calender interval columns for each group.
* Compute the mean value for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the mean values for them.
*
@ -302,11 +301,11 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def avg(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Average)
aggregateNumericColumns(colNames : _*)(Average)
}
/**
* Compute the min value for each numeric calender interval column for each group.
* Compute the min value for each numeric column for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the min values for them.
*
@ -314,11 +313,11 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def min(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Min)
aggregateNumericColumns(colNames : _*)(Min)
}
/**
* Compute the sum for each numeric calender interval columns for each group.
* Compute the sum for each numeric columns for each group.
* The resulting `DataFrame` will also contain the grouping columns.
* When specified columns are given, only compute the sum for them.
*
@ -326,7 +325,7 @@ class RelationalGroupedDataset protected[sql](
*/
@scala.annotation.varargs
def sum(colNames: String*): DataFrame = {
aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
aggregateNumericColumns(colNames : _*)(Sum)
}
/**


@ -296,14 +296,8 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
}
private[columnar] final class IntervalColumnStats extends ColumnStats {
protected var upper: CalendarInterval = CalendarInterval.MIN_VALUE
protected var lower: CalendarInterval = CalendarInterval.MAX_VALUE
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getInterval(ordinal)
if (value.compareTo(upper) > 0) upper = value
if (value.compareTo(lower) < 0) lower = value
sizeInBytes += CALENDAR_INTERVAL.actualSize(row, ordinal)
count += 1
} else {
@ -312,7 +306,7 @@ private[columnar] final class IntervalColumnStats extends ColumnStats {
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
Array[Any](null, null, nullCount, count, sizeInBytes)
}
private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {


@ -1,47 +1,5 @@
-- test for intervals
-- greater than or equal
select interval '1 day' > interval '23 hour';
select interval '-1 day' >= interval '-23 hour';
select interval '-1 day' > null;
select null > interval '-1 day';
-- less than or equal
select interval '1 minutes' < interval '1 hour';
select interval '-1 day' <= interval '-23 hour';
-- equal
select interval '1 year' = interval '360 days';
select interval '1 year 2 month' = interval '420 days';
select interval '1 year' = interval '365 days';
select interval '1 month' = interval '30 days';
select interval '1 minutes' = interval '1 hour';
select interval '1 minutes' = null;
select null = interval '-1 day';
-- null safe equal
select interval '1 minutes' <=> null;
select null <=> interval '1 minutes';
-- complex interval representation
select INTERVAL '9 years 1 months -1 weeks -4 days -10 hours -46 minutes' > interval '1 minutes';
-- ordering
select cast(v as interval) i from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v) order by i;
-- unlimited days
select interval '1 month 120 days' > interval '2 month';
select interval '1 month 30 days' = interval '2 month';
-- unlimited microseconds
select interval '1 month 29 days 40 hours' > interval '2 month';
-- max
select max(cast(v as interval)) from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v);
-- min
select min(cast(v as interval)) from VALUES ('1 seconds'), ('4 seconds'), ('3 seconds') t(v);
-- multiply and divide an interval by a number
select 3 * (timestamp'2019-10-15 10:11:12.001002' - date'2019-10-15');
select interval 4 month 2 weeks 3 microseconds * 1.5;


@ -971,24 +971,4 @@ class DataFrameAggregateSuite extends QueryTest
Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
assert(find(df3.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
}
test("Dataset agg functions support calendar intervals") {
val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
val df2 = df1.select($"a", $"b" cast CalendarIntervalType).groupBy($"a" % 2)
checkAnswer(df2.sum("b"),
Row(0, new CalendarInterval(0, 2, 0)) ::
Row(1, new CalendarInterval(0, 4, 0)) :: Nil)
checkAnswer(df2.avg("b"),
Row(0, new CalendarInterval(0, 2, 0)) ::
Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
checkAnswer(df2.mean("b"),
Row(0, new CalendarInterval(0, 2, 0)) ::
Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
checkAnswer(df2.max("b"),
Row(0, new CalendarInterval(0, 2, 0)) ::
Row(1, new CalendarInterval(0, 3, 0)) :: Nil)
checkAnswer(df2.min("b"),
Row(0, new CalendarInterval(0, 2, 0)) ::
Row(1, new CalendarInterval(0, 1, 0)) :: Nil)
}
}


@ -31,7 +31,7 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
testDecimalColumnStats(Array(null, null, 0))
testIntervalColumnStats(Array(CalendarInterval.MAX_VALUE, CalendarInterval.MIN_VALUE, 0))
testIntervalColumnStats(Array(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
@ -126,12 +126,8 @@ class ColumnStatsSuite extends SparkFunSuite {
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
val values = rows.take(10).map(_.get(0, columnType.dataType))
val ordering = CalendarIntervalType.ordering.asInstanceOf[Ordering[Any]]
val stats = columnStats.collectedStatistics
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
assertResult(10, "Wrong null count")(stats(2))
assertResult(20, "Wrong row count")(stats(3))
assertResult(stats(4), "Wrong size in bytes") {


@ -78,7 +78,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(ARRAY_TYPE, Array[Any](1), 4 + 8 + 8 + 8)
checkActualSize(MAP_TYPE, Map(1 -> "a"), 4 + (8 + 8 + 8 + 8) + (8 + 8 + 8 + 8))
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
checkActualSize(CALENDAR_INTERVAL, CalendarInterval.MAX_VALUE, 4 + 4 + 8)
checkActualSize(CALENDAR_INTERVAL, new CalendarInterval(0, 0, 0), 4 + 4 + 8)
}
testNativeColumnType(BOOLEAN)