diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 40b9d6554d..aeb236ea5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -158,7 +158,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
       ComputeCurrentTime,
-      ReplaceCurrentLike(catalogManager)) ::
+      ReplaceCurrentLike(catalogManager),
+      SpecialDatetimeValues) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -265,6 +266,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView.ruleName ::
       ReplaceExpressions.ruleName ::
       ComputeCurrentTime.ruleName ::
+      SpecialDatetimeValues.ruleName ::
       ReplaceCurrentLike(catalogManager).ruleName ::
       RewriteDistinctAggregates.ruleName ::
       ReplaceDeduplicateWithAggregate.ruleName ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index deacc3b9a1..daf4c5e275 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -119,3 +120,25 @@ case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[Logic
     }
   }
 }
+
+/**
+ * Replaces casts of special datetime strings with their date/timestamp values
+ * if the input strings are foldable.
+ */
+object SpecialDatetimeValues extends Rule[LogicalPlan] {
+  private val conv = Map[DataType, (String, java.time.ZoneId) => Option[Any]](
+    DateType -> convertSpecialDate,
+    TimestampType -> convertSpecialTimestamp,
+    TimestampNTZType -> ((s: String, _: java.time.ZoneId) => convertSpecialTimestampNTZ(s))
+  )
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(CAST)) {
+      case cast @ Cast(e, dt @ (DateType | TimestampType | TimestampNTZType), _, _)
+          if e.foldable && e.dataType == StringType =>
+        Option(e.eval())
+          .flatMap(s => conv(dt)(s.toString, cast.zoneId))
+          .map(Literal(_, dt))
+          .getOrElse(cast)
+    }
+  }
+}
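Reviewer note, not part of the patch: the Optimizer.scala hunks register the new rule in the "Finish Analysis" batch next to ComputeCurrentTime and ReplaceCurrentLike, and also add it to the non-excludable rules, so spark.sql.optimizer.excludedRules cannot disable it. The REPL-style sketch below shows the rewrite the rule performs; the names zoneId, before, after and the alias "d" are illustrative only:

    import java.time.ZoneId

    import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal}
    import org.apache.spark.sql.catalyst.optimizer.SpecialDatetimeValues
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
    import org.apache.spark.sql.types.DateType

    val zoneId = ZoneId.systemDefault()
    // Before: Project [cast(yesterday as date) AS d]
    val before = Project(
      Seq(Alias(Cast(Literal("yesterday"), DateType, Some(zoneId.getId)), "d")()),
      LocalRelation())
    // After: Project [<epoch day of yesterday> AS d]; the foldable cast is
    // replaced by Literal(<days>, DateType) at optimization time.
    val after = SpecialDatetimeValues(before)

The rule only fires when the cast input is a foldable StringType expression and the target type is DateType, TimestampType or TimestampNTZType; everything else keeps the original Cast.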
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SpecialDatetimeValuesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SpecialDatetimeValuesSuite.scala
new file mode 100644
index 0000000000..e68a751a6e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SpecialDatetimeValuesSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZoneId}
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MINUTE
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
+import org.apache.spark.sql.types.{AtomicType, DateType, TimestampNTZType, TimestampType}
+
+class SpecialDatetimeValuesSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Seq(Batch("SpecialDatetimeValues", Once, SpecialDatetimeValues))
+  }
+
+  test("special date values") {
+    testSpecialDatetimeValues { zoneId =>
+      val expected = Set(
+        LocalDate.ofEpochDay(0),
+        LocalDate.now(zoneId),
+        LocalDate.now(zoneId).minusDays(1),
+        LocalDate.now(zoneId).plusDays(1)
+      ).map(_.toEpochDay.toInt)
+      val in = Project(Seq(
+        Alias(Cast(Literal("epoch"), DateType, Some(zoneId.getId)), "epoch")(),
+        Alias(Cast(Literal("today"), DateType, Some(zoneId.getId)), "today")(),
+        Alias(Cast(Literal("yesterday"), DateType, Some(zoneId.getId)), "yesterday")(),
+        Alias(Cast(Literal("tomorrow"), DateType, Some(zoneId.getId)), "tomorrow")()),
+        LocalRelation())
+
+      val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+      val lits = new scala.collection.mutable.ArrayBuffer[Int]
+      plan.transformAllExpressions { case e: Literal if e.dataType == DateType =>
+        lits += e.value.asInstanceOf[Int]
+        e
+      }
+      assert(expected === lits.toSet)
+    }
+  }
+
+  private def testSpecialTs(tsType: AtomicType, expected: Set[Long], zoneId: ZoneId): Unit = {
+    val in = Project(Seq(
+      Alias(Cast(Literal("epoch"), tsType, Some(zoneId.getId)), "epoch")(),
+      Alias(Cast(Literal("now"), tsType, Some(zoneId.getId)), "now")(),
+      Alias(Cast(Literal("tomorrow"), tsType, Some(zoneId.getId)), "tomorrow")(),
+      Alias(Cast(Literal("yesterday"), tsType, Some(zoneId.getId)), "yesterday")()),
+      LocalRelation())
+
+    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
+    val lits = new scala.collection.mutable.ArrayBuffer[Long]
+    plan.transformAllExpressions { case e: Literal if e.dataType == tsType =>
+      lits += e.value.asInstanceOf[Long]
+      e
+    }
+    assert(lits.forall(ts => expected.exists(ets => Math.abs(ets - ts) <= MICROS_PER_MINUTE)))
+  }
+
+  test("special timestamp_ltz values") {
+    testSpecialDatetimeValues { zoneId =>
+      val expected = Set(
+        Instant.ofEpochSecond(0),
+        Instant.now(),
+        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).plusDays(1).toInstant,
+        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).minusDays(1).toInstant
+      ).map(instantToMicros)
+      testSpecialTs(TimestampType, expected, zoneId)
+    }
+  }
+
+  test("special timestamp_ntz values") {
+    testSpecialDatetimeValues { zoneId =>
+      val expected = Set(
+        LocalDateTime.of(1970, 1, 1, 0, 0),
+        LocalDateTime.now(),
+        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).plusDays(1),
+        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).minusDays(1)
+      ).map(localDateTimeToMicros)
+      testSpecialTs(TimestampNTZType, expected, zoneId)
+    }
+  }
+}
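Reviewer note, not part of the patch: testSpecialDatetimeValues is not defined in the new suite, so it has to come from shared test utilities; DateTimeTestUtils in org.apache.spark.sql.catalyst.util is the natural home, and the sketch below assumes its outstandingZoneIds list. The helper's name and body here are an assumption, not taken from the patch:

    import java.time.ZoneId

    import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.outstandingZoneIds

    // Assumed helper: run a check once per outstanding test time zone so that
    // zone-sensitive special values like 'today' are exercised across zones.
    def testSpecialDatetimeValues[T](test: ZoneId => T): Unit = {
      outstandingZoneIds.foreach(test)
    }

A design note on the assertions: the timestamp tests allow a one-minute gap (MICROS_PER_MINUTE) between expected and actual literals because 'now' is captured twice, once when the test computes the expected values and once when the rule folds the cast, so exact equality would be flaky.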