[SPARK-36567][SQL] Support foldable special datetime strings by CAST

### What changes were proposed in this pull request?
In the PR, I propose to add a new correctness rule, `SpecialDatetimeValues`, to the final analysis phase. It replaces casts of strings to date/timestamp_ltz/timestamp_ntz with literals of those types when the input strings are foldable and contain special datetime values such as `today`, `yesterday` and `tomorrow`.

### Why are the changes needed?
1. To avoid a breaking change.
2. To improve the user experience with Spark SQL. After the PR https://github.com/apache/spark/pull/32714, users have to use typed literals instead of implicit casts. For instance, this query worked in Spark 3.1:
```sql
select ts_col > 'now';
```
but it currently fails, and users have to use a typed timestamp literal instead:
```sql
select ts_col > timestamp'now';
```
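Under the hood, the analyzer resolves the string operand of such a comparison into a cast to timestamp; since the cast's child is a literal, the cast is foldable, which is the precondition the new rule checks before substituting a timestamp literal. A minimal sketch of that condition using Catalyst classes directly (illustrative only, not code from this PR):

```scala
import java.time.ZoneId

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.TimestampType

// Roughly what the analyzer produces for the string side of `ts_col > 'now'`:
val implicitCast = Cast(Literal("now"), TimestampType, Some(ZoneId.systemDefault().getId))

// A cast of a literal is foldable, so SpecialDatetimeValues may evaluate it
// once during optimization and replace it with the resulting timestamp literal.
assert(implicitCast.foldable)
```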

### Does this PR introduce _any_ user-facing change?
No. The previous release, 3.1, already supported this feature until it was removed by https://github.com/apache/spark/pull/32714.

### How was this patch tested?
1. Manually tested via the SQL command line:
```sql
spark-sql> select cast('today' as date);
2021-08-24
spark-sql> select timestamp('today');
2021-08-24 00:00:00
spark-sql> select timestamp'tomorrow' > 'today';
true
```
2. By running the new test suite:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.catalyst.optimizer.SpecialDatetimeValuesSuite"
```

Closes #33816 from MaxGekk/foldable-datetime-special-values.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit df0ec56723)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
The commit changes 3 files, with 127 additions and 1 deletion:

**`Optimizer.scala`**

```diff
@@ -158,7 +158,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
       ComputeCurrentTime,
-      ReplaceCurrentLike(catalogManager)) ::
+      ReplaceCurrentLike(catalogManager),
+      SpecialDatetimeValues) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -265,6 +266,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView.ruleName ::
       ReplaceExpressions.ruleName ::
       ComputeCurrentTime.ruleName ::
+      SpecialDatetimeValues.ruleName ::
       ReplaceCurrentLike(catalogManager).ruleName ::
       RewriteDistinctAggregates.ruleName ::
       ReplaceDeduplicateWithAggregate.ruleName ::
```
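The second hunk registers the rule in the optimizer's `nonExcludableRules` list, so it cannot be disabled through `spark.sql.optimizer.excludedRules`; the rewrite is treated as a correctness fixup rather than an optional optimization. A sketch of what that means in practice (hypothetical session setup):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  // Listing a non-excludable rule here is a no-op: the optimizer still
  // applies SpecialDatetimeValues in its finish-analysis batch.
  .config("spark.sql.optimizer.excludedRules",
    "org.apache.spark.sql.catalyst.optimizer.SpecialDatetimeValues")
  .getOrCreate()
```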

**`finishAnalysis.scala`**

```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -119,3 +120,25 @@ case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[Logic
     }
   }
 }
+
+/**
+ * Replaces casts of special datetime strings by its date/timestamp values
+ * if the input strings are foldable.
+ */
+object SpecialDatetimeValues extends Rule[LogicalPlan] {
+  private val conv = Map[DataType, (String, java.time.ZoneId) => Option[Any]](
+    DateType -> convertSpecialDate,
+    TimestampType -> convertSpecialTimestamp,
+    TimestampNTZType -> ((s: String, _: java.time.ZoneId) => convertSpecialTimestampNTZ(s))
+  )
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(CAST)) {
+      case cast @ Cast(e, dt @ (DateType | TimestampType | TimestampNTZType), _, _)
+          if e.foldable && e.dataType == StringType =>
+        Option(e.eval())
+          .flatMap(s => conv(dt)(s.toString, cast.zoneId))
+          .map(Literal(_, dt))
+          .getOrElse(cast)
+    }
+  }
+}
```
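The rewrite falls back to the original cast whenever the conversion helper returns `None`, so only the recognized special strings are affected. A small sketch of the assumed behavior of `convertSpecialDate`, which returns the number of days since the epoch:

```scala
import java.time.ZoneId

import org.apache.spark.sql.catalyst.util.DateTimeUtils.convertSpecialDate

val utc = ZoneId.of("UTC")
// 'epoch' maps to day 0, i.e. 1970-01-01.
assert(convertSpecialDate("epoch", utc).contains(0))
// Ordinary (non-special) strings yield None, so the rule keeps the cast unchanged.
assert(convertSpecialDate("2021-08-25", utc).isEmpty)
```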

**`SpecialDatetimeValuesSuite.scala`** (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZoneId}

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MINUTE
// Assumed import, not visible in the extracted diff: the shared test helper
// that runs the given check under several default JVM time zones.
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.testSpecialDatetimeValues
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
import org.apache.spark.sql.types.{AtomicType, DateType, TimestampNTZType, TimestampType}

class SpecialDatetimeValuesSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("SpecialDatetimeValues", Once, SpecialDatetimeValues))
  }

  test("special date values") {
    testSpecialDatetimeValues { zoneId =>
      val expected = Set(
        LocalDate.ofEpochDay(0),
        LocalDate.now(zoneId),
        LocalDate.now(zoneId).minusDays(1),
        LocalDate.now(zoneId).plusDays(1)
      ).map(_.toEpochDay.toInt)
      val in = Project(Seq(
        Alias(Cast(Literal("epoch"), DateType, Some(zoneId.getId)), "epoch")(),
        Alias(Cast(Literal("today"), DateType, Some(zoneId.getId)), "today")(),
        Alias(Cast(Literal("yesterday"), DateType, Some(zoneId.getId)), "yesterday")(),
        Alias(Cast(Literal("tomorrow"), DateType, Some(zoneId.getId)), "tomorrow")()),
        LocalRelation())
      val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
      val lits = new scala.collection.mutable.ArrayBuffer[Int]
      plan.transformAllExpressions { case e: Literal if e.dataType == DateType =>
        lits += e.value.asInstanceOf[Int]
        e
      }
      assert(expected === lits.toSet)
    }
  }

  private def testSpecialTs(tsType: AtomicType, expected: Set[Long], zoneId: ZoneId): Unit = {
    val in = Project(Seq(
      Alias(Cast(Literal("epoch"), tsType, Some(zoneId.getId)), "epoch")(),
      Alias(Cast(Literal("now"), tsType, Some(zoneId.getId)), "now")(),
      Alias(Cast(Literal("tomorrow"), tsType, Some(zoneId.getId)), "tomorrow")(),
      Alias(Cast(Literal("yesterday"), tsType, Some(zoneId.getId)), "yesterday")()),
      LocalRelation())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal if e.dataType == tsType =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    // Compare with a one-minute tolerance, since 'now' is evaluated at
    // slightly different moments in the rule and in the expected set.
    assert(lits.forall(ts => expected.exists(ets => Math.abs(ets - ts) <= MICROS_PER_MINUTE)))
  }

  test("special timestamp_ltz values") {
    testSpecialDatetimeValues { zoneId =>
      val expected = Set(
        Instant.ofEpochSecond(0),
        Instant.now(),
        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).plusDays(1).toInstant,
        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).minusDays(1).toInstant
      ).map(instantToMicros)
      testSpecialTs(TimestampType, expected, zoneId)
    }
  }

  test("special timestamp_ntz values") {
    testSpecialDatetimeValues { zoneId =>
      val expected = Set(
        LocalDateTime.of(1970, 1, 1, 0, 0),
        LocalDateTime.now(),
        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).plusDays(1),
        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).minusDays(1)
      ).map(localDateTimeToMicros)
      testSpecialTs(TimestampNTZType, expected, zoneId)
    }
  }
}
```
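End to end, the folding can also be observed from the Scala API; a rough sketch, assuming a local `SparkSession` is in scope as `spark` (it mirrors the manual `spark-sql` check above):

```scala
// Both special strings are folded to datetime literals during optimization,
// so the comparison reduces to a constant before execution.
val row = spark.sql("SELECT timestamp'tomorrow' > 'today'").head()
assert(row.getBoolean(0))
```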