[SPARK-11351] [SQL] support hive interval literal
Author: Wenchen Fan <wenchen@databricks.com> Closes #9304 from cloud-fan/interval.
This commit is contained in:
parent
e5b89978ed
commit
0cb7662d86
|
@ -322,7 +322,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
protected lazy val literal: Parser[Literal] =
|
||||
( numericLiteral
|
||||
| booleanLiteral
|
||||
| stringLit ^^ {case s => Literal.create(s, StringType) }
|
||||
| stringLit ^^ { case s => Literal.create(s, StringType) }
|
||||
| intervalLiteral
|
||||
| NULL ^^^ Literal.create(null, NullType)
|
||||
)
|
||||
|
@ -349,13 +349,12 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
protected lazy val integral: Parser[String] =
|
||||
sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n }
|
||||
|
||||
private def intervalUnit(unitName: String) =
|
||||
acceptIf {
|
||||
case lexical.Identifier(str) =>
|
||||
val normalized = lexical.normalizeKeyword(str)
|
||||
normalized == unitName || normalized == unitName + "s"
|
||||
case _ => false
|
||||
} {_ => "wrong interval unit"}
|
||||
private def intervalUnit(unitName: String) = acceptIf {
|
||||
case lexical.Identifier(str) =>
|
||||
val normalized = lexical.normalizeKeyword(str)
|
||||
normalized == unitName || normalized == unitName + "s"
|
||||
case _ => false
|
||||
} {_ => "wrong interval unit"}
|
||||
|
||||
protected lazy val month: Parser[Int] =
|
||||
integral <~ intervalUnit("month") ^^ { case num => num.toInt }
|
||||
|
@ -396,21 +395,53 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
|
|||
case num => num.toLong * CalendarInterval.MICROS_PER_WEEK
|
||||
}
|
||||
|
||||
private def intervalKeyword(keyword: String) = acceptIf {
|
||||
case lexical.Identifier(str) =>
|
||||
lexical.normalizeKeyword(str) == keyword
|
||||
case _ => false
|
||||
} {_ => "wrong interval keyword"}
|
||||
|
||||
protected lazy val intervalLiteral: Parser[Literal] =
|
||||
INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
|
||||
millisecond.? ~ microsecond.? ^^ {
|
||||
case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
|
||||
( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~
|
||||
intervalKeyword("month") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromYearMonthString(s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~
|
||||
intervalKeyword("second") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromDayTimeString(s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("year", s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("month", s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("day", s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("hour", s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("minute", s))
|
||||
}
|
||||
| INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s =>
|
||||
Literal(CalendarInterval.fromSingleUnitString("second", s))
|
||||
}
|
||||
| INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
|
||||
millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
|
||||
millisecond ~ microsecond =>
|
||||
if (!Seq(year, month, week, day, hour, minute, second,
|
||||
millisecond, microsecond).exists(_.isDefined)) {
|
||||
throw new AnalysisException(
|
||||
"at least one time unit should be given for interval literal")
|
||||
}
|
||||
val months = Seq(year, month).map(_.getOrElse(0)).sum
|
||||
val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
|
||||
.map(_.getOrElse(0L)).sum
|
||||
Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType)
|
||||
if (!Seq(year, month, week, day, hour, minute, second,
|
||||
millisecond, microsecond).exists(_.isDefined)) {
|
||||
throw new AnalysisException(
|
||||
"at least one time unit should be given for interval literal")
|
||||
}
|
||||
val months = Seq(year, month).map(_.getOrElse(0)).sum
|
||||
val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
|
||||
.map(_.getOrElse(0L)).sum
|
||||
Literal(new CalendarInterval(months, microseconds))
|
||||
}
|
||||
)
|
||||
|
||||
private def toNarrowestIntegerType(value: String): Any = {
|
||||
val bigIntValue = BigDecimal(value)
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
|
|||
import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute}
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command}
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
|
||||
override def output: Seq[Attribute] = Seq.empty
|
||||
|
@ -74,4 +75,55 @@ class SqlParserSuite extends PlanTest {
|
|||
OneRowRelation)
|
||||
comparePlans(parsed, expected)
|
||||
}
|
||||
|
||||
test("support hive interval literal") {
|
||||
def checkInterval(sql: String, result: CalendarInterval): Unit = {
|
||||
val parsed = SqlParser.parse(sql)
|
||||
val expected = Project(
|
||||
UnresolvedAlias(
|
||||
Literal(result)
|
||||
) :: Nil,
|
||||
OneRowRelation)
|
||||
comparePlans(parsed, expected)
|
||||
}
|
||||
|
||||
def checkYearMonth(lit: String): Unit = {
|
||||
checkInterval(
|
||||
s"SELECT INTERVAL '$lit' YEAR TO MONTH",
|
||||
CalendarInterval.fromYearMonthString(lit))
|
||||
}
|
||||
|
||||
def checkDayTime(lit: String): Unit = {
|
||||
checkInterval(
|
||||
s"SELECT INTERVAL '$lit' DAY TO SECOND",
|
||||
CalendarInterval.fromDayTimeString(lit))
|
||||
}
|
||||
|
||||
def checkSingleUnit(lit: String, unit: String): Unit = {
|
||||
checkInterval(
|
||||
s"SELECT INTERVAL '$lit' $unit",
|
||||
CalendarInterval.fromSingleUnitString(unit, lit))
|
||||
}
|
||||
|
||||
checkYearMonth("123-10")
|
||||
checkYearMonth("496-0")
|
||||
checkYearMonth("-2-3")
|
||||
checkYearMonth("-123-0")
|
||||
|
||||
checkDayTime("99 11:22:33.123456789")
|
||||
checkDayTime("-99 11:22:33.123456789")
|
||||
checkDayTime("10 9:8:7.123456789")
|
||||
checkDayTime("1 0:0:0")
|
||||
checkDayTime("-1 0:0:0")
|
||||
checkDayTime("1 0:0:1")
|
||||
|
||||
for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
|
||||
checkSingleUnit("7", unit)
|
||||
checkSingleUnit("-7", unit)
|
||||
checkSingleUnit("0", unit)
|
||||
}
|
||||
|
||||
checkSingleUnit("13.123456789", "second")
|
||||
checkSingleUnit("-13.123456789", "second")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue