[SPARK-17296][SQL] Simplify parser join processing.
## What changes were proposed in this pull request? Join processing in the parser relies on the fact that the grammar produces a right nested trees, for instance the parse tree for `select * from a join b join c` is expected to produce a tree similar to `JOIN(a, JOIN(b, c))`. However there are cases in which this (invariant) is violated, like: ```sql SELECT COUNT(1) FROM test T1 CROSS JOIN test T2 JOIN test T3 ON T3.col = T1.col JOIN test T4 ON T4.col = T1.col ``` In this case the parser returns a tree in which Joins are located on both the left and the right sides of the parent join node. This PR introduces a different grammar rule which does not make this assumption. The new rule takes a relation and searches for zero or more joined relations. As a bonus processing is much easier. ## How was this patch tested? Existing tests and I have added a regression test to the plan parser suite. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #14867 from hvanhovell/SPARK-17296.
This commit is contained in:
parent
29cfab3f15
commit
4f769b903b
|
@ -374,11 +374,12 @@ setQuantifier
|
|||
;
|
||||
|
||||
relation
|
||||
: left=relation
|
||||
(joinType JOIN right=relation joinCriteria?
|
||||
| NATURAL joinType JOIN right=relation
|
||||
) #joinRelation
|
||||
| relationPrimary #relationDefault
|
||||
: relationPrimary joinRelation*
|
||||
;
|
||||
|
||||
joinRelation
|
||||
: (joinType) JOIN right=relationPrimary joinCriteria?
|
||||
| NATURAL joinType JOIN right=relationPrimary
|
||||
;
|
||||
|
||||
joinType
|
||||
|
|
|
@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
|
|||
|
||||
// Apply CTEs
|
||||
query.optional(ctx.ctes) {
|
||||
val ctes = ctx.ctes.namedQuery.asScala.map {
|
||||
case nCtx =>
|
||||
val namedQuery = visitNamedQuery(nCtx)
|
||||
(namedQuery.alias, namedQuery)
|
||||
val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
|
||||
val namedQuery = visitNamedQuery(nCtx)
|
||||
(namedQuery.alias, namedQuery)
|
||||
}
|
||||
// Check for duplicate names.
|
||||
checkDuplicateKeys(ctes, ctx)
|
||||
|
@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
|
|||
* separated) relations here, these get converted into a single plan by condition-less inner join.
|
||||
*/
|
||||
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
|
||||
val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
|
||||
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
|
||||
val right = plan(relation.relationPrimary)
|
||||
val join = right.optionalMap(left)(Join(_, _, Inner, None))
|
||||
withJoinRelations(join, relation)
|
||||
}
|
||||
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
|
||||
}
|
||||
|
||||
|
@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
|
|||
}
|
||||
|
||||
/**
|
||||
* Create a joins between two or more logical plans.
|
||||
* Create a single relation referenced in a FROM claused. This method is used when a part of the
|
||||
* join condition is nested, for example:
|
||||
* {{{
|
||||
* select * from t1 join (t2 cross join t3) on col1 = col2
|
||||
* }}}
|
||||
*/
|
||||
override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
|
||||
/** Build a join between two plans. */
|
||||
def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
|
||||
val baseJoinType = ctx.joinType match {
|
||||
case null => Inner
|
||||
case jt if jt.CROSS != null => Cross
|
||||
case jt if jt.FULL != null => FullOuter
|
||||
case jt if jt.SEMI != null => LeftSemi
|
||||
case jt if jt.ANTI != null => LeftAnti
|
||||
case jt if jt.LEFT != null => LeftOuter
|
||||
case jt if jt.RIGHT != null => RightOuter
|
||||
case _ => Inner
|
||||
}
|
||||
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
|
||||
withJoinRelations(plan(ctx.relationPrimary), ctx)
|
||||
}
|
||||
|
||||
// Resolve the join type and join condition
|
||||
val (joinType, condition) = Option(ctx.joinCriteria) match {
|
||||
case Some(c) if c.USING != null =>
|
||||
val columns = c.identifier.asScala.map { column =>
|
||||
UnresolvedAttribute.quoted(column.getText)
|
||||
}
|
||||
(UsingJoin(baseJoinType, columns), None)
|
||||
case Some(c) if c.booleanExpression != null =>
|
||||
(baseJoinType, Option(expression(c.booleanExpression)))
|
||||
case None if ctx.NATURAL != null =>
|
||||
(NaturalJoin(baseJoinType), None)
|
||||
case None =>
|
||||
(baseJoinType, None)
|
||||
}
|
||||
Join(left, right, joinType, condition)
|
||||
}
|
||||
/**
|
||||
* Join one more [[LogicalPlan]]s to the current logical plan.
|
||||
*/
|
||||
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
|
||||
ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
|
||||
withOrigin(join) {
|
||||
val baseJoinType = join.joinType match {
|
||||
case null => Inner
|
||||
case jt if jt.CROSS != null => Cross
|
||||
case jt if jt.FULL != null => FullOuter
|
||||
case jt if jt.SEMI != null => LeftSemi
|
||||
case jt if jt.ANTI != null => LeftAnti
|
||||
case jt if jt.LEFT != null => LeftOuter
|
||||
case jt if jt.RIGHT != null => RightOuter
|
||||
case _ => Inner
|
||||
}
|
||||
|
||||
// Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
|
||||
// first join clause is at the top. However fields of previously referenced tables can be used
|
||||
// in following join clauses. The tree needs to be reversed in order to make this work.
|
||||
var result = plan(ctx.left)
|
||||
var current = ctx
|
||||
while (current != null) {
|
||||
current.right match {
|
||||
case right: JoinRelationContext =>
|
||||
result = join(current, result, plan(right.left))
|
||||
current = right
|
||||
case right =>
|
||||
result = join(current, result, plan(right))
|
||||
current = null
|
||||
// Resolve the join type and join condition
|
||||
val (joinType, condition) = Option(join.joinCriteria) match {
|
||||
case Some(c) if c.USING != null =>
|
||||
val columns = c.identifier.asScala.map { column =>
|
||||
UnresolvedAttribute.quoted(column.getText)
|
||||
}
|
||||
(UsingJoin(baseJoinType, columns), None)
|
||||
case Some(c) if c.booleanExpression != null =>
|
||||
(baseJoinType, Option(expression(c.booleanExpression)))
|
||||
case None if join.NATURAL != null =>
|
||||
if (baseJoinType == Cross) {
|
||||
throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
|
||||
}
|
||||
(NaturalJoin(baseJoinType), None)
|
||||
case None =>
|
||||
(baseJoinType, None)
|
||||
}
|
||||
Join(left, plan(join.right), joinType, condition)
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser
|
|||
|
||||
import scala.collection.mutable.StringBuilder
|
||||
|
||||
import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
|
||||
import org.antlr.v4.runtime.{ParserRuleContext, Token}
|
||||
import org.antlr.v4.runtime.misc.Interval
|
||||
import org.antlr.v4.runtime.tree.TerminalNode
|
||||
|
||||
|
@ -189,9 +189,7 @@ object ParserUtils {
|
|||
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
|
||||
* passed function. The original plan is returned when the context does not exist.
|
||||
*/
|
||||
def optionalMap[C <: ParserRuleContext](
|
||||
ctx: C)(
|
||||
f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
|
||||
def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
|
||||
if (ctx != null) {
|
||||
f(ctx, plan)
|
||||
} else {
|
||||
|
|
|
@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest {
|
|||
test("left anti join", LeftAnti, testExistence)
|
||||
test("anti join", LeftAnti, testExistence)
|
||||
|
||||
// Test natural cross join
|
||||
intercept("select * from a natural cross join b")
|
||||
|
||||
// Test natural join with a condition
|
||||
intercept("select * from a natural join b on a.id = b.id")
|
||||
|
||||
// Test multiple consecutive joins
|
||||
assertEqual(
|
||||
"select * from a join b join c right join d",
|
||||
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
|
||||
|
||||
// SPARK-17296
|
||||
assertEqual(
|
||||
"select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
|
||||
table("t1")
|
||||
.join(table("t2"), Cross)
|
||||
.join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
|
||||
.join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
|
||||
.select(star()))
|
||||
|
||||
// Test multiple on clauses.
|
||||
intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")
|
||||
|
||||
// Parenthesis
|
||||
assertEqual(
|
||||
"select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
|
||||
table("t1")
|
||||
.join(table("t2")
|
||||
.join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
|
||||
.select(star()))
|
||||
assertEqual(
|
||||
"select * from t1 inner join (t2 inner join t3) on col3 = col2",
|
||||
table("t1")
|
||||
.join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
|
||||
.select(star()))
|
||||
assertEqual(
|
||||
"select * from t1 inner join (t2 inner join t3 on col3 = col2)",
|
||||
table("t1")
|
||||
.join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
|
||||
.select(star()))
|
||||
|
||||
// Implicit joins.
|
||||
assertEqual(
|
||||
"select * from t1, t3 join t2 on t1.col1 = t2.col2",
|
||||
table("t1")
|
||||
.join(table("t3"))
|
||||
.join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
|
||||
.select(star()))
|
||||
}
|
||||
|
||||
test("sampled relations") {
|
||||
|
|
Loading…
Reference in a new issue