[SPARK-20311][SQL] Support aliases for table value functions

## What changes were proposed in this pull request?
This pr added parsing rules to support aliases in table value functions.

## How was this patch tested?
Added tests in `PlanParserSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17666 from maropu/SPARK-20311.
This commit is contained in:
Takeshi Yamamuro 2017-05-09 20:22:51 +08:00 committed by Wenchen Fan
parent 0d00c768a8
commit 714811d0b5
6 changed files with 79 additions and 17 deletions

View file

@ -472,15 +472,23 @@ identifierComment
;
relationPrimary
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
| inlineTable #inlineTableDefault2
| functionTable #tableValuedFunction
;
inlineTable
: VALUES expression (',' expression)* (AS? identifier identifierList?)?
: VALUES expression (',' expression)* tableAlias
;
functionTable
: identifier '(' (expression (',' expression)*)? ')' tableAlias
;
tableAlias
: (AS? identifier identifierList?)?
;
rowFormat

View file

@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
case Some(tvf) =>
val resolved = tvf.flatMap { case (argList, resolver) =>
argList.implicitCast(u.functionArgs) match {
@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
case _ =>
u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
}
// If alias names assigned, add `Project` with the aliases
if (u.outputNames.nonEmpty) {
val outputAttrs = resolvedFunc.output
// Checks if the number of the aliases is equal to expected one
if (u.outputNames.size != outputAttrs.size) {
u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
s"found ${u.outputNames.size} columns")
}
val aliases = outputAttrs.zip(u.outputNames).map {
case (attr, name) => Alias(attr, name)()
}
Project(aliases, resolvedFunc)
} else {
resolvedFunc
}
}
}

View file

@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
/**
* A table-valued function, e.g.
* {{{
* select * from range(10);
* select id from range(10);
*
* // Assign alias names
* select t.a from range(10) t(a);
* }}}
*/
case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
case class UnresolvedTableValuedFunction(
functionName: String,
functionArgs: Seq[Expression],
outputNames: Seq[String])
extends LeafNode {
override def output: Seq[Attribute] = Nil

View file

@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
*/
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
: LogicalPlan = withOrigin(ctx) {
UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
val func = ctx.functionTable
val aliases = if (func.tableAlias.identifierList != null) {
visitIdentifierList(func.tableAlias.identifierList)
} else {
Seq.empty
}
val tvf = UnresolvedTableValuedFunction(
func.identifier.getText, func.expression.asScala.map(expression), aliases)
tvf.optionalMap(func.tableAlias.identifier)(aliasPlan)
}
/**
@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
}
val aliases = if (ctx.identifierList != null) {
visitIdentifierList(ctx.identifierList)
val aliases = if (ctx.tableAlias.identifierList != null) {
visitIdentifierList(ctx.tableAlias.identifierList)
} else {
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}
val table = UnresolvedInlineTable(aliases, rows)
table.optionalMap(ctx.identifier)(aliasPlan)
table.optionalMap(ctx.tableAlias.identifier)(aliasPlan)
}
/**

View file

@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Cross
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
}
test("SPARK-20311 range(N) as alias") {
def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
.select(star())
}
assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
assertAnalysisError(
rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
Seq("expected 1 columns but found 2 columns"))
}
}

View file

@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
test("table valued function") {
assertEqual(
"select * from range(2)",
UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
}
test("SPARK-20311 range(N) as alias") {
assertEqual(
"select * from range(10) AS t",
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
.select(star()))
assertEqual(
"select * from range(7) AS t(a)",
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
.select(star()))
}
test("inline table") {