[SPARK-36567][SQL] Support foldable special datetime strings by CAST
### What changes were proposed in this pull request?
In this PR, I propose to add a new correctness rule, `SpecialDatetimeValues`, to the final analysis phase. It replaces casts of strings to date/timestamp_ltz/timestamp_ntz with literals of those types if the strings contain special datetime values such as `today`, `yesterday` and `tomorrow`, and the input strings are foldable.
### Why are the changes needed?
1. To avoid a breaking change.
2. To improve user experience with Spark SQL. After the PR https://github.com/apache/spark/pull/32714, users have to use typed literals instead of implicit casts. For instance,
the following query works in Spark 3.1:
```sql
select ts_col > 'now';
```
but the same query currently fails, and users have to use a typed timestamp literal instead:
```sql
select ts_col > timestamp'now';
```
### Does this PR introduce _any_ user-facing change?
No. The previous release, 3.1, already supported this feature until it was removed by https://github.com/apache/spark/pull/32714.
### How was this patch tested?
1. Manually tested via the sql command line:
```sql
spark-sql> select cast('today' as date);
2021-08-24
spark-sql> select timestamp('today');
2021-08-24 00:00:00
spark-sql> select timestamp'tomorrow' > 'today';
true
```
2. By running new test suite:
```
$ build/sbt "sql/testOnly org.apache.spark.sql.catalyst.optimizer.SpecialDatetimeValuesSuite"
```
Closes #33816 from MaxGekk/foldable-datetime-special-values.
Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit df0ec56723)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
beabf91ea1
commit
a4c5140242
|
@ -158,7 +158,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
|
|||
RewriteNonCorrelatedExists,
|
||||
PullOutGroupingExpressions,
|
||||
ComputeCurrentTime,
|
||||
ReplaceCurrentLike(catalogManager)) ::
|
||||
ReplaceCurrentLike(catalogManager),
|
||||
SpecialDatetimeValues) ::
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Optimizer rules start here
|
||||
//////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -265,6 +266,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
|
|||
EliminateView.ruleName ::
|
||||
ReplaceExpressions.ruleName ::
|
||||
ComputeCurrentTime.ruleName ::
|
||||
SpecialDatetimeValues.ruleName ::
|
||||
ReplaceCurrentLike(catalogManager).ruleName ::
|
||||
RewriteDistinctAggregates.ruleName ::
|
||||
ReplaceDeduplicateWithAggregate.ruleName ::
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
|
|||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
import org.apache.spark.sql.catalyst.trees.TreePattern._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
|
||||
import org.apache.spark.sql.connector.catalog.CatalogManager
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -119,3 +120,25 @@ case class ReplaceCurrentLike(catalogManager: CatalogManager) extends Rule[Logic
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Rewrites a cast of a foldable string expression to `DateType`, `TimestampType` or
 * `TimestampNTZType` into a literal of the target type, provided the matching
 * `convertSpecial*` helper recognises the evaluated string. A cast whose input evaluates
 * to `null`, or whose string is not recognised by the helper, is left unchanged.
 */
object SpecialDatetimeValues extends Rule[LogicalPlan] {
  /**
   * Dispatches to the conversion helper for the requested target type.
   * The zone id is ignored for `TimestampNTZType`, whose helper takes only the string.
   */
  private def convertSpecial(
      target: DataType,
      text: String,
      zone: java.time.ZoneId): Option[Any] = target match {
    case DateType => convertSpecialDate(text, zone)
    case TimestampType => convertSpecialTimestamp(text, zone)
    case TimestampNTZType => convertSpecialTimestampNTZ(text)
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    // Only visit subtrees that actually contain a cast.
    plan.transformAllExpressionsWithPruning(_.containsPattern(CAST)) {
      case cast @ Cast(child, target @ (DateType | TimestampType | TimestampNTZType), _, _)
          if child.foldable && child.dataType == StringType =>
        val replacement = for {
          // `Option(...)` turns a null evaluation result into `None`.
          value <- Option(child.eval())
          special <- convertSpecial(target, value.toString, cast.zoneId)
        } yield Literal(special, target)
        // Fall back to the original cast when nothing could be substituted.
        replacement.getOrElse(cast)
    }
  }
}
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import java.time.{Instant, LocalDate, LocalDateTime, LocalTime, ZoneId}
|
||||
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal}
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MINUTE
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
|
||||
import org.apache.spark.sql.types.{AtomicType, DateType, TimestampNTZType, TimestampType}
|
||||
|
||||
/**
 * Checks that the `SpecialDatetimeValues` rule replaces casts of special datetime strings
 * with literals of the requested date/timestamp type.
 */
class SpecialDatetimeValuesSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("SpecialDatetimeValues", Once, SpecialDatetimeValues))
  }

  test("special date values") {
    testSpecialDatetimeValues { zoneId =>
      // Read the clock once so every relative date is derived from the same day.
      val today = LocalDate.now(zoneId)
      val expected = Set(
        LocalDate.ofEpochDay(0),
        today,
        today.minusDays(1),
        today.plusDays(1)
      ).map(_.toEpochDay.toInt)
      val in = Project(Seq(
        Alias(Cast(Literal("epoch"), DateType, Some(zoneId.getId)), "epoch")(),
        Alias(Cast(Literal("today"), DateType, Some(zoneId.getId)), "today")(),
        Alias(Cast(Literal("yesterday"), DateType, Some(zoneId.getId)), "yesterday")(),
        Alias(Cast(Literal("tomorrow"), DateType, Some(zoneId.getId)), "tomorrow")()),
        LocalRelation())

      val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
      // Gather every date literal that the rule left in the projection.
      val lits = plan.expressions.flatMap(_.collect {
        case Literal(value: Int, DateType) => value
      })
      assert(expected === lits.toSet)
    }
  }

  /**
   * Casts the special strings `epoch`, `now`, `tomorrow` and `yesterday` to `tsType`, runs the
   * rule, and compares the resulting literals against `expected`.
   *
   * @param tsType   target timestamp type of the casts
   * @param expected expected literal values in microseconds; a literal matches an expected
   *                 value when they differ by at most one minute
   * @param zoneId   time zone attached to every cast
   */
  private def testSpecialTs(tsType: AtomicType, expected: Set[Long], zoneId: ZoneId): Unit = {
    val specialValues = Seq("epoch", "now", "tomorrow", "yesterday")
    val in = Project(
      specialValues.map(v => Alias(Cast(Literal(v), tsType, Some(zoneId.getId)), v)()),
      LocalRelation())

    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val lits = plan.expressions.flatMap(_.collect {
      case Literal(value: Long, `tsType`) => value
    })

    def closeTo(a: Long, b: Long): Boolean = Math.abs(a - b) <= MICROS_PER_MINUTE

    // `forall` is vacuously true on an empty collection, so first make sure that the rule
    // actually replaced every cast with a literal.
    assert(lits.size === specialValues.size,
      s"expected ${specialValues.size} literals of type $tsType but found ${lits.size}")
    assert(lits.forall(ts => expected.exists(ets => closeTo(ets, ts))))
    // Every expected value must be produced by at least one literal as well.
    assert(expected.forall(ets => lits.exists(ts => closeTo(ets, ts))))
  }

  test("special timestamp_ltz values") {
    testSpecialDatetimeValues { zoneId =>
      val expected = Set(
        Instant.ofEpochSecond(0),
        Instant.now(),
        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).plusDays(1).toInstant,
        Instant.now().atZone(zoneId).`with`(LocalTime.MIDNIGHT).minusDays(1).toInstant
      ).map(instantToMicros)
      testSpecialTs(TimestampType, expected, zoneId)
    }
  }

  test("special timestamp_ntz values") {
    testSpecialDatetimeValues { zoneId =>
      val expected = Set(
        LocalDateTime.of(1970, 1, 1, 0, 0),
        LocalDateTime.now(),
        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).plusDays(1),
        LocalDateTime.now().`with`(LocalTime.MIDNIGHT).minusDays(1)
      ).map(localDateTimeToMicros)
      testSpecialTs(TimestampNTZType, expected, zoneId)
    }
  }
}
|
Loading…
Reference in a new issue