[SPARK-30925][SQL] Prevent overflow/round errors in conversions of milliseconds to/from microseconds
### What changes were proposed in this pull request? - Use `Math.multiplyExact()` in `DateTimeUtils.fromMillis()` to prevent silent overflow in conversion milliseconds to microseconds. - Use `DateTimeUtils.fromMillis()` in all places where milliseconds are converted to microseconds - Use `DateTimeUtils.toMillis()` in all places where microseconds are converted to milliseconds ### Why are the changes needed? 1. To prevent silent arithmetic overflow while multiplying by 1000 in `fromMillis()`. Instead of it, `new ArithmeticException("long overflow")` will be thrown, and handled accordantly. 2. To correctly round microseconds in conversion to milliseconds. For example, `1965-01-01 10:11:12.123456` is represented as `-157700927876544` in micro precision. In milliseconds precision the above needs to be represented as `-157700927877` or `1965-01-01 10:11:12.123`. ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? By `TimestampFormatterSuite`, `CastSuite`, `DateExpressionsSuite`, `IntervalExpressionsSuite`, `ExpressionParserSuite`, `ExpressionParserSuite`, `DateTimeUtilsSuite`, `IntervalUtilsSuite` Closes #27676 from MaxGekk/millis-2-micros-overflow. Authored-by: Maxim Gekk <max.gekk@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
7aa94ca9cb
commit
c41ef39819
|
@ -65,7 +65,7 @@ object DateTimeUtils {
|
|||
}
|
||||
|
||||
def millisToDays(millisUtc: Long, zoneId: ZoneId): SQLDate = {
|
||||
val instant = microsToInstant(Math.multiplyExact(millisUtc, MICROS_PER_MILLIS))
|
||||
val instant = microsToInstant(fromMillis(millisUtc))
|
||||
localDateToDays(LocalDateTime.ofInstant(instant, zoneId).toLocalDate)
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ object DateTimeUtils {
|
|||
|
||||
def daysToMillis(days: SQLDate, zoneId: ZoneId): Long = {
|
||||
val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant
|
||||
instantToMicros(instant) / MICROS_PER_MILLIS
|
||||
toMillis(instantToMicros(instant))
|
||||
}
|
||||
|
||||
// Converts Timestamp to string according to Hive TimestampWritable convention.
|
||||
|
@ -149,7 +149,7 @@ object DateTimeUtils {
|
|||
* Converts milliseconds since epoch to SQLTimestamp.
|
||||
*/
|
||||
def fromMillis(millis: Long): SQLTimestamp = {
|
||||
MILLISECONDS.toMicros(millis)
|
||||
Math.multiplyExact(millis, MICROS_PER_MILLIS)
|
||||
}
|
||||
|
||||
def microsToEpochDays(epochMicros: SQLTimestamp, zoneId: ZoneId): SQLDate = {
|
||||
|
@ -574,8 +574,8 @@ object DateTimeUtils {
|
|||
time2: SQLTimestamp,
|
||||
roundOff: Boolean,
|
||||
zoneId: ZoneId): Double = {
|
||||
val millis1 = MICROSECONDS.toMillis(time1)
|
||||
val millis2 = MICROSECONDS.toMillis(time2)
|
||||
val millis1 = toMillis(time1)
|
||||
val millis2 = toMillis(time2)
|
||||
val date1 = millisToDays(millis1, zoneId)
|
||||
val date2 = millisToDays(millis2, zoneId)
|
||||
val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1)
|
||||
|
@ -714,7 +714,7 @@ object DateTimeUtils {
|
|||
case TRUNC_TO_HOUR => truncToUnit(t, zoneId, ChronoUnit.HOURS)
|
||||
case TRUNC_TO_DAY => truncToUnit(t, zoneId, ChronoUnit.DAYS)
|
||||
case _ =>
|
||||
val millis = MICROSECONDS.toMillis(t)
|
||||
val millis = toMillis(t)
|
||||
val truncated = level match {
|
||||
case TRUNC_TO_MILLISECOND => millis
|
||||
case TRUNC_TO_SECOND =>
|
||||
|
@ -725,7 +725,7 @@ object DateTimeUtils {
|
|||
val dDays = millisToDays(millis, zoneId)
|
||||
daysToMillis(truncDate(dDays, level), zoneId)
|
||||
}
|
||||
truncated * MICROS_PER_MILLIS
|
||||
fromMillis(truncated)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
|
|||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.fromMillis
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -704,9 +705,7 @@ object IntervalUtils {
|
|||
microseconds = Math.addExact(microseconds, minutesUs)
|
||||
i += minuteStr.numBytes()
|
||||
} else if (s.matchAt(millisStr, i)) {
|
||||
val millisUs = Math.multiplyExact(
|
||||
currentValue,
|
||||
MICROS_PER_MILLIS)
|
||||
val millisUs = fromMillis(currentValue)
|
||||
microseconds = Math.addExact(microseconds, millisUs)
|
||||
i += millisStr.numBytes()
|
||||
} else if (s.matchAt(microsStr, i)) {
|
||||
|
|
|
@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit.SECONDS
|
|||
import org.apache.commons.lang3.time.FastDateFormat
|
||||
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ convertSpecialTimestamp, SQLTimestamp}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
|
||||
|
@ -141,7 +141,7 @@ class LegacyFastTimestampFormatter(
|
|||
}
|
||||
val micros = cal.getMicros()
|
||||
cal.set(Calendar.MILLISECOND, 0)
|
||||
cal.getTimeInMillis * MICROS_PER_MILLIS + micros
|
||||
Math.addExact(fromMillis(cal.getTimeInMillis), micros)
|
||||
}
|
||||
|
||||
def format(timestamp: SQLTimestamp): String = {
|
||||
|
@ -164,7 +164,7 @@ class LegacySimpleTimestampFormatter(
|
|||
}
|
||||
|
||||
override def parse(s: String): Long = {
|
||||
sdf.parse(s).getTime * MICROS_PER_MILLIS
|
||||
fromMillis(sdf.parse(s).getTime)
|
||||
}
|
||||
|
||||
override def format(us: Long): String = {
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.util.{Calendar, TimeZone}
|
||||
import java.util.concurrent.TimeUnit._
|
||||
|
||||
import scala.collection.parallel.immutable.ParVector
|
||||
|
||||
|
@ -272,13 +271,13 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(
|
||||
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
|
||||
TimestampType, timeZoneId),
|
||||
MILLISECONDS.toMicros(c.getTimeInMillis))
|
||||
fromMillis(c.getTimeInMillis))
|
||||
c = Calendar.getInstance(TimeZoneGMT)
|
||||
c.set(2015, 10, 1, 2, 30, 0)
|
||||
checkEvaluation(
|
||||
cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
|
||||
TimestampType, timeZoneId),
|
||||
MILLISECONDS.toMicros(c.getTimeInMillis))
|
||||
fromMillis(c.getTimeInMillis))
|
||||
}
|
||||
|
||||
val gmtId = Option("GMT")
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp}
|
|||
import java.text.SimpleDateFormat
|
||||
import java.time.{Instant, LocalDate, LocalDateTime, ZoneId, ZoneOffset}
|
||||
import java.util.{Calendar, Locale, TimeZone}
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.TimeUnit._
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
@ -48,7 +47,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
|
||||
def toMillis(timestamp: String): Long = {
|
||||
val tf = TimestampFormatter("yyyy-MM-dd HH:mm:ss", ZoneOffset.UTC)
|
||||
TimeUnit.MICROSECONDS.toMillis(tf.parse(timestamp))
|
||||
DateTimeUtils.toMillis(tf.parse(timestamp))
|
||||
}
|
||||
val date = "2015-04-08 13:10:15"
|
||||
val d = new Date(toMillis(date))
|
||||
|
|
|
@ -21,6 +21,7 @@ import scala.language.implicitConversions
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
|
||||
import org.apache.spark.sql.catalyst.util.IntervalUtils.{safeStringToInterval, stringToInterval}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
|
@ -260,7 +261,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
seconds: Int = 0,
|
||||
millis: Int = 0,
|
||||
micros: Int = 0): Unit = {
|
||||
val secFrac = seconds * MICROS_PER_SECOND + millis * MICROS_PER_MILLIS + micros
|
||||
val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros)
|
||||
val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks),
|
||||
Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, 8, 6)))
|
||||
val totalMonths = years * MONTHS_PER_YEAR + months
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
|
||||
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -681,13 +680,13 @@ class ExpressionParserSuite extends AnalysisTest {
|
|||
Literal(new CalendarInterval(
|
||||
0,
|
||||
0,
|
||||
-13 * MICROS_PER_SECOND - 123 * MICROS_PER_MILLIS - 456)))
|
||||
DateTimeTestUtils.secFrac(-13, -123, -456))))
|
||||
checkIntervals(
|
||||
"13.123456 second",
|
||||
Literal(new CalendarInterval(
|
||||
0,
|
||||
0,
|
||||
13 * MICROS_PER_SECOND + 123 * MICROS_PER_MILLIS + 456)))
|
||||
DateTimeTestUtils.secFrac(13, 123, 456))))
|
||||
checkIntervals("1.001 second",
|
||||
Literal(IntervalUtils.stringToInterval("1 second 1 millisecond")))
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@ import java.time.{LocalDate, LocalDateTime, LocalTime, ZoneId, ZoneOffset}
|
|||
import java.util.TimeZone
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
|
||||
/**
|
||||
* Helper functions for testing date and time functionality.
|
||||
*/
|
||||
|
@ -95,4 +97,11 @@ object DateTimeTestUtils {
|
|||
val localDateTime = LocalDateTime.of(localDate, localTime)
|
||||
localDateTimeToMicros(localDateTime, zid)
|
||||
}
|
||||
|
||||
def secFrac(seconds: Int, milliseconds: Int, microseconds: Int): Long = {
|
||||
var result: Long = microseconds
|
||||
result = Math.addExact(result, Math.multiplyExact(milliseconds, MICROS_PER_MILLIS))
|
||||
result = Math.addExact(result, Math.multiplyExact(seconds, MICROS_PER_SECOND))
|
||||
result
|
||||
}
|
||||
}
|
||||
|
|
|
@ -89,8 +89,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
|
|||
|
||||
test("SPARK-6785: java date conversion before and after epoch") {
|
||||
def format(d: Date): String = {
|
||||
TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId)
|
||||
.format(d.getTime * MICROS_PER_MILLIS)
|
||||
TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId).format(fromMillis(d.getTime))
|
||||
}
|
||||
def checkFromToJavaDate(d1: Date): Unit = {
|
||||
val d2 = toJavaDate(fromJavaDate(d1))
|
||||
|
@ -584,15 +583,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
|
|||
}
|
||||
|
||||
test("daysToMillis and millisToDays") {
|
||||
val input = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, 16, zid = zonePST))
|
||||
val input = toMillis(date(2015, 12, 31, 16, zid = zonePST))
|
||||
assert(millisToDays(input, zonePST) === 16800)
|
||||
assert(millisToDays(input, ZoneOffset.UTC) === 16801)
|
||||
assert(millisToDays(-1 * MILLIS_PER_DAY + 1, ZoneOffset.UTC) == -1)
|
||||
|
||||
var expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, zid = zonePST))
|
||||
var expected = toMillis(date(2015, 12, 31, zid = zonePST))
|
||||
assert(daysToMillis(16800, zonePST) === expected)
|
||||
|
||||
expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, zid = zoneGMT))
|
||||
expected = toMillis(date(2015, 12, 31, zid = zoneGMT))
|
||||
assert(daysToMillis(16800, ZoneOffset.UTC) === expected)
|
||||
|
||||
// There are some days are skipped entirely in some timezone, skip them here.
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.plans.SQLHelper
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.fromMillis
|
||||
import org.apache.spark.sql.catalyst.util.IntervalUtils._
|
||||
import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
@ -76,7 +77,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
|
|||
testSingleUnit("HouR", 3, 0, 0, 3 * MICROS_PER_HOUR)
|
||||
testSingleUnit("MiNuTe", 3, 0, 0, 3 * MICROS_PER_MINUTE)
|
||||
testSingleUnit("Second", 3, 0, 0, 3 * MICROS_PER_SECOND)
|
||||
testSingleUnit("MilliSecond", 3, 0, 0, 3 * MICROS_PER_MILLIS)
|
||||
testSingleUnit("MilliSecond", 3, 0, 0, fromMillis(3))
|
||||
testSingleUnit("MicroSecond", 3, 0, 0, 3)
|
||||
|
||||
checkFromInvalidString(null, "cannot be null")
|
||||
|
@ -175,7 +176,7 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
|
|||
new CalendarInterval(
|
||||
0,
|
||||
10,
|
||||
12 * MICROS_PER_MINUTE + 888 * MICROS_PER_MILLIS))
|
||||
12 * MICROS_PER_MINUTE + fromMillis(888)))
|
||||
assert(fromDayTimeString("-3 0:0:0") === new CalendarInterval(0, -3, 0L))
|
||||
|
||||
try {
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.toMillis
|
||||
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
|
||||
import org.apache.spark.sql.types.MetadataBuilder
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
@ -100,7 +100,7 @@ case class EventTimeWatermarkExec(
|
|||
child.execute().mapPartitions { iter =>
|
||||
val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output)
|
||||
iter.map { row =>
|
||||
eventTimeStats.add(getEventTime(row).getLong(0) / MICROS_PER_MILLIS)
|
||||
eventTimeStats.add(toMillis(getEventTime(row).getLong(0)))
|
||||
row
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
|
|||
import scala.concurrent.duration.Duration
|
||||
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.toMillis
|
||||
import org.apache.spark.sql.catalyst.util.IntervalUtils
|
||||
import org.apache.spark.sql.streaming.Trigger
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -36,7 +37,8 @@ private object Triggers {
|
|||
if (cal.months != 0) {
|
||||
throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
|
||||
}
|
||||
TimeUnit.MICROSECONDS.toMillis(cal.microseconds + cal.days * MICROS_PER_DAY)
|
||||
val microsInDays = Math.multiplyExact(cal.days, MICROS_PER_DAY)
|
||||
toMillis(Math.addExact(cal.microseconds, microsInDays))
|
||||
}
|
||||
|
||||
def convert(interval: Duration): Long = interval.toMillis
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.sql
|
|||
import java.{lang => jl}
|
||||
import java.io.File
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
@ -30,6 +29,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatisti
|
|||
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, HistogramBin, HistogramSerializer, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
|
@ -51,10 +51,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
|
|||
private val d2 = Date.valueOf(d2Str)
|
||||
private val t1Str = "2016-05-08 00:00:01.000000"
|
||||
private val t1Internal = date(2016, 5, 8, 0, 0, 1)
|
||||
private val t1 = new Timestamp(TimeUnit.MICROSECONDS.toMillis(t1Internal))
|
||||
private val t1 = new Timestamp(DateTimeUtils.toMillis(t1Internal))
|
||||
private val t2Str = "2016-05-09 00:00:02.000000"
|
||||
private val t2Internal = date(2016, 5, 9, 0, 0, 2)
|
||||
private val t2 = new Timestamp(TimeUnit.MICROSECONDS.toMillis(t2Internal))
|
||||
private val t2 = new Timestamp(DateTimeUtils.toMillis(t2Internal))
|
||||
|
||||
/**
|
||||
* Define a very simple 3 row table used for testing column serialization.
|
||||
|
|
Loading…
Reference in a new issue