[SPARK-30925][SQL] Prevent overflow/rounding errors in conversions of milliseconds to/from microseconds

### What changes were proposed in this pull request?
- Use `Math.multiplyExact()` in `DateTimeUtils.fromMillis()` to prevent silent overflow in the conversion of milliseconds to microseconds.
- Use `DateTimeUtils.fromMillis()` in all places where milliseconds are converted to microseconds.
- Use `DateTimeUtils.toMillis()` in all places where microseconds are converted to milliseconds; a sketch of the intended semantics of both helpers follows this list.
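
For orientation, a minimal sketch of the intended semantics of the two helpers after this change (a self-contained illustration, not the exact Spark code; the local constant stands in for `DateTimeConstants.MICROS_PER_MILLIS`):

```scala
object MillisMicrosSketch {
  // Stand-in for DateTimeConstants.MICROS_PER_MILLIS (assumed here for self-containment).
  private val MicrosPerMillis = 1000L

  // Millis -> micros: Math.multiplyExact throws ArithmeticException("long overflow")
  // on overflow instead of silently wrapping like a plain `millis * 1000`.
  def fromMillis(millis: Long): Long = Math.multiplyExact(millis, MicrosPerMillis)

  // Micros -> millis: Math.floorDiv rounds toward negative infinity, so pre-epoch
  // timestamps keep the expected millisecond value (plain `/` truncates toward zero).
  def toMillis(micros: Long): Long = Math.floorDiv(micros, MicrosPerMillis)
}
```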

### Why are the changes needed?

1. To prevent silent arithmetic overflow while multiplying by 1000 in `fromMillis()`. Instead of overflowing silently, a `new ArithmeticException("long overflow")` is thrown and handled accordingly.
2. To round microseconds correctly when converting to milliseconds. For example, `1965-01-01 10:11:12.123456` is represented as `-157700927876544` in microsecond precision. In millisecond precision this must be represented as `-157700927877`, i.e. `1965-01-01 10:11:12.123` (both failure modes are illustrated in the snippet below).
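
To make both failure modes concrete, a small illustration (values reuse the example above; this snippet is not part of the patch):

```scala
// 1. Silent overflow: plain multiplication wraps around, multiplyExact fails fast.
val millis = Long.MaxValue / 100
println(millis * 1000L)                 // wraps to a meaningless negative value
// Math.multiplyExact(millis, 1000L)    // would throw ArithmeticException: long overflow

// 2. Rounding before the epoch: truncating division loses the millisecond component.
val micros = -157700927876544L          // 1965-01-01 10:11:12.123456
println(micros / 1000L)                 // -157700927876 (truncated toward zero)
println(Math.floorDiv(micros, 1000L))   // -157700927877, i.e. 1965-01-01 10:11:12.123
```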

### Does this PR introduce any user-facing change?
Yes. A millisecond-to-microsecond conversion that previously overflowed silently now fails with an `ArithmeticException`, and microsecond-to-millisecond conversions of pre-epoch timestamps now round correctly.

### How was this patch tested?
By `TimestampFormatterSuite`, `CastSuite`, `DateExpressionsSuite`, `IntervalExpressionsSuite`, `ExpressionParserSuite`, `DateTimeUtilsSuite`, and `IntervalUtilsSuite`.

Closes #27676 from MaxGekk/millis-2-micros-overflow.

Authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit c41ef39819 (parent 7aa94ca9cb), 2020-02-24 14:06:25 +08:00
13 changed files with 43 additions and 35 deletions

@@ -65,7 +65,7 @@ object DateTimeUtils {
}
def millisToDays(millisUtc: Long, zoneId: ZoneId): SQLDate = {
val instant = microsToInstant(Math.multiplyExact(millisUtc, MICROS_PER_MILLIS))
val instant = microsToInstant(fromMillis(millisUtc))
localDateToDays(LocalDateTime.ofInstant(instant, zoneId).toLocalDate)
}
@@ -76,7 +76,7 @@ object DateTimeUtils {
def daysToMillis(days: SQLDate, zoneId: ZoneId): Long = {
val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant
instantToMicros(instant) / MICROS_PER_MILLIS
toMillis(instantToMicros(instant))
}
// Converts Timestamp to string according to Hive TimestampWritable convention.
@@ -149,7 +149,7 @@ object DateTimeUtils {
* Converts milliseconds since epoch to SQLTimestamp.
*/
def fromMillis(millis: Long): SQLTimestamp = {
MILLISECONDS.toMicros(millis)
Math.multiplyExact(millis, MICROS_PER_MILLIS)
}
def microsToEpochDays(epochMicros: SQLTimestamp, zoneId: ZoneId): SQLDate = {
@@ -574,8 +574,8 @@ object DateTimeUtils {
time2: SQLTimestamp,
roundOff: Boolean,
zoneId: ZoneId): Double = {
val millis1 = MICROSECONDS.toMillis(time1)
val millis2 = MICROSECONDS.toMillis(time2)
val millis1 = toMillis(time1)
val millis2 = toMillis(time2)
val date1 = millisToDays(millis1, zoneId)
val date2 = millisToDays(millis2, zoneId)
val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1)
@@ -714,7 +714,7 @@ object DateTimeUtils {
case TRUNC_TO_HOUR => truncToUnit(t, zoneId, ChronoUnit.HOURS)
case TRUNC_TO_DAY => truncToUnit(t, zoneId, ChronoUnit.DAYS)
case _ =>
val millis = MICROSECONDS.toMillis(t)
val millis = toMillis(t)
val truncated = level match {
case TRUNC_TO_MILLISECOND => millis
case TRUNC_TO_SECOND =>
@@ -725,7 +725,7 @@ object DateTimeUtils {
val dDays = millisToDays(millis, zoneId)
daysToMillis(truncDate(dDays, level), zoneId)
}
truncated * MICROS_PER_MILLIS
fromMillis(truncated)
}
}

@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.fromMillis
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -704,9 +705,7 @@ object IntervalUtils {
microseconds = Math.addExact(microseconds, minutesUs)
i += minuteStr.numBytes()
} else if (s.matchAt(millisStr, i)) {
val millisUs = Math.multiplyExact(
currentValue,
MICROS_PER_MILLIS)
val millisUs = fromMillis(currentValue)
microseconds = Math.addExact(microseconds, millisUs)
i += millisStr.numBytes()
} else if (s.matchAt(microsStr, i)) {

@@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit.SECONDS
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialTimestamp, SQLTimestamp}
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal
@@ -141,7 +141,7 @@ class LegacyFastTimestampFormatter(
}
val micros = cal.getMicros()
cal.set(Calendar.MILLISECOND, 0)
cal.getTimeInMillis * MICROS_PER_MILLIS + micros
Math.addExact(fromMillis(cal.getTimeInMillis), micros)
}
def format(timestamp: SQLTimestamp): String = {
@@ -164,7 +164,7 @@ class LegacySimpleTimestampFormatter(
}
override def parse(s: String): Long = {
sdf.parse(s).getTime * MICROS_PER_MILLIS
fromMillis(sdf.parse(s).getTime)
}
override def format(us: Long): String = {

@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import java.util.{Calendar, TimeZone}
import java.util.concurrent.TimeUnit._
import scala.collection.parallel.immutable.ParVector
@@ -272,13 +271,13 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
TimestampType, timeZoneId),
MILLISECONDS.toMicros(c.getTimeInMillis))
fromMillis(c.getTimeInMillis))
c = Calendar.getInstance(TimeZoneGMT)
c.set(2015, 10, 1, 2, 30, 0)
checkEvaluation(
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
TimestampType, timeZoneId),
MILLISECONDS.toMicros(c.getTimeInMillis))
fromMillis(c.getTimeInMillis))
}
val gmtId = Option("GMT")

@@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.{Instant, LocalDate, LocalDateTime, ZoneId, ZoneOffset}
import java.util.{Calendar, Locale, TimeZone}
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit._
import org.apache.spark.SparkFunSuite
@@ -48,7 +47,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
def toMillis(timestamp: String): Long = {
val tf = TimestampFormatter("yyyy-MM-dd HH:mm:ss", ZoneOffset.UTC)
TimeUnit.MICROSECONDS.toMillis(tf.parse(timestamp))
DateTimeUtils.toMillis(tf.parse(timestamp))
}
val date = "2015-04-08 13:10:15"
val d = new Date(toMillis(date))

@@ -21,6 +21,7 @@ import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.Decimal
@@ -260,7 +261,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
seconds: Int = 0,
millis: Int = 0,
micros: Int = 0): Unit = {
val secFrac = seconds * MICROS_PER_SECOND + millis * MICROS_PER_MILLIS + micros
val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros)
val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks),
Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, 8, 6)))
val totalMonths = years * MONTHS_PER_YEAR + months

@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -681,13 +680,13 @@ class ExpressionParserSuite extends AnalysisTest {
Literal(new CalendarInterval(
0,
0,
-13 * MICROS_PER_SECOND - 123 * MICROS_PER_MILLIS - 456)))
DateTimeTestUtils.secFrac(-13, -123, -456))))
checkIntervals(
"13.123456 second",
Literal(new CalendarInterval(
0,
0,
13 * MICROS_PER_SECOND + 123 * MICROS_PER_MILLIS + 456)))
DateTimeTestUtils.secFrac(13, 123, 456))))
checkIntervals("1.001 second",
Literal(IntervalUtils.stringToInterval("1 second 1 millisecond")))

@@ -21,6 +21,8 @@ import java.time.{LocalDate, LocalDateTime, LocalTime, ZoneId, ZoneOffset}
import java.util.TimeZone
import java.util.concurrent.TimeUnit
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
/**
* Helper functions for testing date and time functionality.
*/
@@ -95,4 +97,11 @@ object DateTimeTestUtils {
val localDateTime = LocalDateTime.of(localDate, localTime)
localDateTimeToMicros(localDateTime, zid)
}
def secFrac(seconds: Int, milliseconds: Int, microseconds: Int): Long = {
var result: Long = microseconds
result = Math.addExact(result, Math.multiplyExact(milliseconds, MICROS_PER_MILLIS))
result = Math.addExact(result, Math.multiplyExact(seconds, MICROS_PER_SECOND))
result
}
}

@@ -89,8 +89,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
test("SPARK-6785: java date conversion before and after epoch") {
def format(d: Date): String = {
TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId)
.format(d.getTime * MICROS_PER_MILLIS)
TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId).format(fromMillis(d.getTime))
}
def checkFromToJavaDate(d1: Date): Unit = {
val d2 = toJavaDate(fromJavaDate(d1))
@@ -584,15 +583,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
}
test("daysToMillis and millisToDays") {
val input = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, 16, zid = zonePST))
val input = toMillis(date(2015, 12, 31, 16, zid = zonePST))
assert(millisToDays(input, zonePST) === 16800)
assert(millisToDays(input, ZoneOffset.UTC) === 16801)
assert(millisToDays(-1 * MILLIS_PER_DAY + 1, ZoneOffset.UTC) == -1)
var expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, zid = zonePST))
var expected = toMillis(date(2015, 12, 31, zid = zonePST))
assert(daysToMillis(16800, zonePST) === expected)
expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, zid = zoneGMT))
expected = toMillis(date(2015, 12, 31, zid = zoneGMT))
assert(daysToMillis(16800, ZoneOffset.UTC) === expected)
// There are some days are skipped entirely in some timezone, skip them here.

@@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.fromMillis
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
import org.apache.spark.sql.internal.SQLConf
@@ -76,7 +77,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
testSingleUnit("HouR", 3, 0, 0, 3 * MICROS_PER_HOUR)
testSingleUnit("MiNuTe", 3, 0, 0, 3 * MICROS_PER_MINUTE)
testSingleUnit("Second", 3, 0, 0, 3 * MICROS_PER_SECOND)
testSingleUnit("MilliSecond", 3, 0, 0, 3 * MICROS_PER_MILLIS)
testSingleUnit("MilliSecond", 3, 0, 0, fromMillis(3))
testSingleUnit("MicroSecond", 3, 0, 0, 3)
checkFromInvalidString(null, "cannot be null")
@@ -175,7 +176,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
new CalendarInterval(
0,
10,
12 * MICROS_PER_MINUTE + 888 * MICROS_PER_MILLIS))
12 * MICROS_PER_MINUTE + fromMillis(888)))
assert(fromDayTimeString("-3 0:0:0") === new CalendarInterval(0, -3, 0L))
try {

@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.DateTimeUtils.toMillis
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.unsafe.types.CalendarInterval
@@ -100,7 +100,7 @@ case class EventTimeWatermarkExec(
child.execute().mapPartitions { iter =>
val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output)
iter.map { row =>
eventTimeStats.add(getEventTime(row).getLong(0) / MICROS_PER_MILLIS)
eventTimeStats.add(toMillis(getEventTime(row).getLong(0)))
row
}
}

@@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
import org.apache.spark.sql.catalyst.util.DateTimeUtils.toMillis
import org.apache.spark.sql.catalyst.util.IntervalUtils
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.unsafe.types.UTF8String
@@ -36,7 +37,8 @@ private object Triggers {
if (cal.months != 0) {
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
}
TimeUnit.MICROSECONDS.toMillis(cal.microseconds + cal.days * MICROS_PER_DAY)
val microsInDays = Math.multiplyExact(cal.days, MICROS_PER_DAY)
toMillis(Math.addExact(cal.microseconds, microsInDays))
}
def convert(interval: Duration): Long = interval.toMillis

@@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.{lang => jl}
import java.io.File
import java.sql.{Date, Timestamp}
import java.util.concurrent.TimeUnit
import scala.collection.mutable
import scala.util.Random
@@ -30,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatisti
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SQLTestUtils
@@ -51,10 +51,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
private val d2 = Date.valueOf(d2Str)
private val t1Str = "2016-05-08 00:00:01.000000"
private val t1Internal = date(2016, 5, 8, 0, 0, 1)
private val t1 = new Timestamp(TimeUnit.MICROSECONDS.toMillis(t1Internal))
private val t1 = new Timestamp(DateTimeUtils.toMillis(t1Internal))
private val t2Str = "2016-05-09 00:00:02.000000"
private val t2Internal = date(2016, 5, 9, 0, 0, 2)
private val t2 = new Timestamp(TimeUnit.MICROSECONDS.toMillis(t2Internal))
private val t2 = new Timestamp(DateTimeUtils.toMillis(t2Internal))
/**
* Define a very simple 3 row table used for testing column serialization.