[SPARK-11351] [SQL] support hive interval literal

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9304 from cloud-fan/interval.
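This change teaches SqlParser to accept Hive-style interval literals (a quoted value followed by a unit specification) alongside the unit-list syntax it already supported. A REPL-style sketch of the accepted forms; the queries are my own illustrations, and the import path is assumed from the test suite's package:

    import org.apache.spark.sql.catalyst.SqlParser

    SqlParser.parse("SELECT INTERVAL '3-1' YEAR TO MONTH")             // Hive year-to-month form
    SqlParser.parse("SELECT INTERVAL '2 10:20:30.123' DAY TO SECOND")  // Hive day-to-second form
    SqlParser.parse("SELECT INTERVAL '5' DAY")                         // Hive single-unit form
    SqlParser.parse("SELECT INTERVAL 1 YEAR 2 MONTHS")                 // pre-existing unit-list form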
Authored by Wenchen Fan on 2015-10-28 21:35:57 -07:00; committed by Yin Huai
parent e5b89978ed
commit 0cb7662d86
2 changed files with 103 additions and 20 deletions

SqlParser.scala

@@ -322,7 +322,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val literal: Parser[Literal] =
     ( numericLiteral
     | booleanLiteral
-    | stringLit ^^ {case s => Literal.create(s, StringType) }
+    | stringLit ^^ { case s => Literal.create(s, StringType) }
     | intervalLiteral
     | NULL ^^^ Literal.create(null, NullType)
     )
@@ -349,13 +349,12 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val integral: Parser[String] =
     sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n }
 
-  private def intervalUnit(unitName: String) =
-    acceptIf {
-      case lexical.Identifier(str) =>
-        val normalized = lexical.normalizeKeyword(str)
-        normalized == unitName || normalized == unitName + "s"
-      case _ => false
-    } {_ => "wrong interval unit"}
+  private def intervalUnit(unitName: String) = acceptIf {
+    case lexical.Identifier(str) =>
+      val normalized = lexical.normalizeKeyword(str)
+      normalized == unitName || normalized == unitName + "s"
+    case _ => false
+  } {_ => "wrong interval unit"}
 
   protected lazy val month: Parser[Int] =
     integral <~ intervalUnit("month") ^^ { case num => num.toInt }
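intervalUnit matches a unit name case-insensitively, in singular or plural spelling, and fails with "wrong interval unit" otherwise. A standalone paraphrase of just the matching rule, with a hypothetical matchesUnit helper and the assumption that lexical.normalizeKeyword lowercases identifiers:

    // Hypothetical stand-in for the acceptIf predicate above.
    def matchesUnit(ident: String, unitName: String): Boolean = {
      val normalized = ident.toLowerCase  // stands in for lexical.normalizeKeyword
      normalized == unitName || normalized == unitName + "s"
    }

    assert(matchesUnit("MONTH", "month"))    // case-insensitive
    assert(matchesUnit("months", "month"))   // plural accepted
    assert(!matchesUnit("monthly", "month")) // anything else is rejected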
@@ -396,21 +395,53 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
       case num => num.toLong * CalendarInterval.MICROS_PER_WEEK
     }
 
+  private def intervalKeyword(keyword: String) = acceptIf {
+    case lexical.Identifier(str) =>
+      lexical.normalizeKeyword(str) == keyword
+    case _ => false
+  } {_ => "wrong interval keyword"}
+
   protected lazy val intervalLiteral: Parser[Literal] =
-    INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
-      millisecond.? ~ microsecond.? ^^ {
-        case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
-          millisecond ~ microsecond =>
-        if (!Seq(year, month, week, day, hour, minute, second,
-          millisecond, microsecond).exists(_.isDefined)) {
-          throw new AnalysisException(
-            "at least one time unit should be given for interval literal")
-        }
-        val months = Seq(year, month).map(_.getOrElse(0)).sum
-        val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
-          .map(_.getOrElse(0L)).sum
-        Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType)
-      }
+    ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~
+      intervalKeyword("month") ^^ { case s =>
+        Literal(CalendarInterval.fromYearMonthString(s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~
+      intervalKeyword("second") ^^ { case s =>
+        Literal(CalendarInterval.fromDayTimeString(s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("year", s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("month", s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("day", s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("hour", s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("minute", s))
+      }
+    | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s =>
+        Literal(CalendarInterval.fromSingleUnitString("second", s))
+      }
+    | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
+      millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
+        millisecond ~ microsecond =>
+        if (!Seq(year, month, week, day, hour, minute, second,
+          millisecond, microsecond).exists(_.isDefined)) {
+          throw new AnalysisException(
+            "at least one time unit should be given for interval literal")
+        }
+        val months = Seq(year, month).map(_.getOrElse(0)).sum
+        val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
+          .map(_.getOrElse(0L)).sum
+        Literal(new CalendarInterval(months, microseconds))
+      }
+    )
 
   private def toNarrowestIntegerType(value: String): Any = {
     val bigIntValue = BigDecimal(value)
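The last alternative preserves the pre-existing unit-list syntax: years and months fold into CalendarInterval's month field, while weeks down through microseconds fold into its microsecond field. A rough REPL-style sketch of that arithmetic, assuming the year parser (defined outside this hunk) already scales years to months:

    import org.apache.spark.unsafe.types.CalendarInterval

    // INTERVAL 1 YEAR 2 MONTHS 3 WEEKS 4 DAYS, summed the way the parser does it.
    val months = 1 * 12 + 2                       // years are carried as months
    val micros = 3L * CalendarInterval.MICROS_PER_WEEK +
                 4L * CalendarInterval.MICROS_PER_DAY
    val interval = new CalendarInterval(months, micros)  // 14 months, 25 days of microseconds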

SqlParserSuite.scala

@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command}
+import org.apache.spark.unsafe.types.CalendarInterval
 
 private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
   override def output: Seq[Attribute] = Seq.empty
@@ -74,4 +75,55 @@ class SqlParserSuite extends PlanTest {
       OneRowRelation)
     comparePlans(parsed, expected)
   }
+
+  test("support hive interval literal") {
+    def checkInterval(sql: String, result: CalendarInterval): Unit = {
+      val parsed = SqlParser.parse(sql)
+      val expected = Project(
+        UnresolvedAlias(
+          Literal(result)
+        ) :: Nil,
+        OneRowRelation)
+      comparePlans(parsed, expected)
+    }
+
+    def checkYearMonth(lit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' YEAR TO MONTH",
+        CalendarInterval.fromYearMonthString(lit))
+    }
+
+    def checkDayTime(lit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' DAY TO SECOND",
+        CalendarInterval.fromDayTimeString(lit))
+    }
+
+    def checkSingleUnit(lit: String, unit: String): Unit = {
+      checkInterval(
+        s"SELECT INTERVAL '$lit' $unit",
+        CalendarInterval.fromSingleUnitString(unit, lit))
+    }
+
+    checkYearMonth("123-10")
+    checkYearMonth("496-0")
+    checkYearMonth("-2-3")
+    checkYearMonth("-123-0")
+
+    checkDayTime("99 11:22:33.123456789")
+    checkDayTime("-99 11:22:33.123456789")
+    checkDayTime("10 9:8:7.123456789")
+    checkDayTime("1 0:0:0")
+    checkDayTime("-1 0:0:0")
+    checkDayTime("1 0:0:1")
+
+    for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
+      checkSingleUnit("7", unit)
+      checkSingleUnit("-7", unit)
+      checkSingleUnit("0", unit)
+    }
+
+    checkSingleUnit("13.123456789", "second")
+    checkSingleUnit("-13.123456789", "second")
+  }
 }
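For intuition about the CalendarInterval helpers the suite exercises: fromYearMonthString parses Hive's 'y-m' format into a pure month count and leaves the day-time part at zero. A quick REPL-style check, with the arithmetic done by hand:

    import org.apache.spark.unsafe.types.CalendarInterval

    val i = CalendarInterval.fromYearMonthString("123-10")
    assert(i.months == 123 * 12 + 10)  // 1486 months
    assert(i.microseconds == 0L)       // no day-time component in this form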