[SPARK-20311][SQL] Support aliases for table value functions
## What changes were proposed in this pull request? This pr added parsing rules to support aliases in table value functions. ## How was this patch tested? Added tests in `PlanParserSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #17666 from maropu/SPARK-20311.
This commit is contained in:
parent
0d00c768a8
commit
714811d0b5
|
@ -472,15 +472,23 @@ identifierComment
|
|||
;
|
||||
|
||||
relationPrimary
|
||||
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
|
||||
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
|
||||
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
|
||||
| inlineTable #inlineTableDefault2
|
||||
| identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction
|
||||
: tableIdentifier sample? (AS? strictIdentifier)? #tableName
|
||||
| '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery
|
||||
| '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation
|
||||
| inlineTable #inlineTableDefault2
|
||||
| functionTable #tableValuedFunction
|
||||
;
|
||||
|
||||
inlineTable
|
||||
: VALUES expression (',' expression)* (AS? identifier identifierList?)?
|
||||
: VALUES expression (',' expression)* tableAlias
|
||||
;
|
||||
|
||||
functionTable
|
||||
: identifier '(' (expression (',' expression)*)? ')' tableAlias
|
||||
;
|
||||
|
||||
tableAlias
|
||||
: (AS? identifier identifierList?)?
|
||||
;
|
||||
|
||||
rowFormat
|
||||
|
|
|
@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
|
|||
|
||||
import java.util.Locale
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
|
||||
|
||||
|
@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
|
|||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
|
||||
case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
|
||||
builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
|
||||
val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
|
||||
case Some(tvf) =>
|
||||
val resolved = tvf.flatMap { case (argList, resolver) =>
|
||||
argList.implicitCast(u.functionArgs) match {
|
||||
|
@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
|
|||
case _ =>
|
||||
u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
|
||||
}
|
||||
|
||||
// If alias names assigned, add `Project` with the aliases
|
||||
if (u.outputNames.nonEmpty) {
|
||||
val outputAttrs = resolvedFunc.output
|
||||
// Checks if the number of the aliases is equal to expected one
|
||||
if (u.outputNames.size != outputAttrs.size) {
|
||||
u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
|
||||
s"found ${u.outputNames.size} columns")
|
||||
}
|
||||
val aliases = outputAttrs.zip(u.outputNames).map {
|
||||
case (attr, name) => Alias(attr, name)()
|
||||
}
|
||||
Project(aliases, resolvedFunc)
|
||||
} else {
|
||||
resolvedFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
|
|||
/**
|
||||
* A table-valued function, e.g.
|
||||
* {{{
|
||||
* select * from range(10);
|
||||
* select id from range(10);
|
||||
*
|
||||
* // Assign alias names
|
||||
* select t.a from range(10) t(a);
|
||||
* }}}
|
||||
*/
|
||||
case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
|
||||
case class UnresolvedTableValuedFunction(
|
||||
functionName: String,
|
||||
functionArgs: Seq[Expression],
|
||||
outputNames: Seq[String])
|
||||
extends LeafNode {
|
||||
|
||||
override def output: Seq[Attribute] = Nil
|
||||
|
|
|
@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
|
|||
*/
|
||||
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
|
||||
: LogicalPlan = withOrigin(ctx) {
|
||||
UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
|
||||
val func = ctx.functionTable
|
||||
val aliases = if (func.tableAlias.identifierList != null) {
|
||||
visitIdentifierList(func.tableAlias.identifierList)
|
||||
} else {
|
||||
Seq.empty
|
||||
}
|
||||
|
||||
val tvf = UnresolvedTableValuedFunction(
|
||||
func.identifier.getText, func.expression.asScala.map(expression), aliases)
|
||||
tvf.optionalMap(func.tableAlias.identifier)(aliasPlan)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
val aliases = if (ctx.identifierList != null) {
|
||||
visitIdentifierList(ctx.identifierList)
|
||||
val aliases = if (ctx.tableAlias.identifierList != null) {
|
||||
visitIdentifierList(ctx.tableAlias.identifierList)
|
||||
} else {
|
||||
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
|
||||
}
|
||||
|
||||
val table = UnresolvedInlineTable(aliases, rows)
|
||||
table.optionalMap(ctx.identifier)(aliasPlan)
|
||||
table.optionalMap(ctx.tableAlias.identifier)(aliasPlan)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
|
|||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
|
||||
import org.apache.spark.sql.catalyst.plans.Cross
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
|
|||
|
||||
checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
|
||||
}
|
||||
|
||||
test("SPARK-20311 range(N) as alias") {
|
||||
def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
|
||||
SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
|
||||
.select(star())
|
||||
}
|
||||
assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
|
||||
assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
|
||||
assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
|
||||
assertAnalysisError(
|
||||
rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
|
||||
Seq("expected 1 columns but found 2 columns"))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
|
|||
test("table valued function") {
|
||||
assertEqual(
|
||||
"select * from range(2)",
|
||||
UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
|
||||
UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
|
||||
}
|
||||
|
||||
test("SPARK-20311 range(N) as alias") {
|
||||
assertEqual(
|
||||
"select * from range(10) AS t",
|
||||
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
|
||||
.select(star()))
|
||||
assertEqual(
|
||||
"select * from range(7) AS t(a)",
|
||||
SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
|
||||
.select(star()))
|
||||
}
|
||||
|
||||
test("inline table") {
|
||||
|
|
Loading…
Reference in a new issue