[SPARK-26840][SQL] Avoid cost-based join reorder in presence of join hints

## What changes were proposed in this pull request?

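This is a follow-up fix for https://github.com/apache/spark/pull/23524. That change only checked for join hints on the top-level join matched by the CostBasedJoinReorder rule; when the rule recursed down the tree, nested joins that carried hints were still collected and reordered. This patch makes the recursive helpers stop at any join that has a hint.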

The issue had not been detected by the existing tests because CBO is disabled by default.

## How was this patch tested?

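Enabled CBO and join reorder for the "hints prevent join reorder" test in JoinHintSuite, and added new test cases to JoinReorderSuite covering hints on nested joins.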

Closes #23759 from maryannxue/spark-26840.

Lead-authored-by: maryannxue <maryannxue@apache.org>
Co-authored-by: Dongjoon Hyun <dhyun@apple.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
This commit is contained in:
maryannxue 2019-02-14 16:56:55 -08:00 committed by gatorsmile
parent 8656af98c0
commit a7e3da42cd
3 changed files with 80 additions and 53 deletions

View file

@ -43,10 +43,10 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
val result = plan transformDown { val result = plan transformDown {
// Start reordering with a joinable item, which is an InnerLike join with conditions. // Start reordering with a joinable item, which is an InnerLike join with conditions.
// Avoid reordering if a join hint is present. // Avoid reordering if a join hint is present.
case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => case j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE) =>
reorder(j, j.output) reorder(j, j.output)
case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint)) case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output) reorder(p, p.output)
} }
// After reordering is finished, convert OrderedJoin back to Join. // After reordering is finished, convert OrderedJoin back to Join.
@ -77,12 +77,12 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
*/ */
private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = {
plan match { plan match {
case Join(left, right, _: InnerLike, Some(cond), _) => case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) =>
val (leftPlans, leftConditions) = extractInnerJoins(left) val (leftPlans, leftConditions) = extractInnerJoins(left)
val (rightPlans, rightConditions) = extractInnerJoins(right) val (rightPlans, rightConditions) = extractInnerJoins(right)
(leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++
leftConditions ++ rightConditions) leftConditions ++ rightConditions)
case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE))
if projectList.forall(_.isInstanceOf[Attribute]) => if projectList.forall(_.isInstanceOf[Attribute]) =>
extractInnerJoins(j) extractInnerJoins(j)
case _ => case _ =>
@ -91,11 +91,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
} }
private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
case j @ Join(left, right, jt: InnerLike, Some(cond), _) => case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) =>
val replacedLeft = replaceWithOrderedJoin(left) val replacedLeft = replaceWithOrderedJoin(left)
val replacedRight = replaceWithOrderedJoin(right) val replacedRight = replaceWithOrderedJoin(right)
OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) =>
p.copy(child = replaceWithOrderedJoin(j)) p.copy(child = replaceWithOrderedJoin(j))
case _ => case _ =>
plan plan

View file

@ -312,6 +312,14 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
assertEqualPlans(originalPlan2, originalPlan2) assertEqualPlans(originalPlan2, originalPlan2)
val originalPlan3 =
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t4).hint("broadcast")
.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
assertEqualPlans(originalPlan3, originalPlan3)
} }
test("reorder below and above the hint node") { test("reorder below and above the hint node") {
@ -342,6 +350,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(t4.hint("broadcast")) .join(t4.hint("broadcast"))
assertEqualPlans(originalPlan2, bestPlan2) assertEqualPlans(originalPlan2, bestPlan2)
val originalPlan3 =
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.hint("broadcast")
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
val bestPlan3 =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)
.hint("broadcast")
.join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100")))
assertEqualPlans(originalPlan3, bestPlan3)
} }
private def assertEqualPlans( private def assertEqualPlans(

View file

@ -102,58 +102,60 @@ class JoinHintSuite extends PlanTest with SharedSQLContext {
} }
test("hints prevent join reorder") { test("hints prevent join reorder") {
withTempView("a", "b", "c") { withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") {
df1.createOrReplaceTempView("a") withTempView("a", "b", "c") {
df2.createOrReplaceTempView("b") df1.createOrReplaceTempView("a")
df3.createOrReplaceTempView("c") df2.createOrReplaceTempView("b")
verifyJoinHint( df3.createOrReplaceTempView("c")
sql("select /*+ broadcast(a, c)*/ * from a, b, c " + verifyJoinHint(
"where a.a1 = b.b1 and b.b1 = c.c1"), sql("select /*+ broadcast(a, c)*/ * from a, b, c " +
JoinHint( "where a.a1 = b.b1 and b.b1 = c.c1"),
None,
Some(HintInfo(broadcast = true))) ::
JoinHint(
Some(HintInfo(broadcast = true)),
None):: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint.NONE ::
JoinHint(
Some(HintInfo(broadcast = true)),
Some(HintInfo(broadcast = true))):: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint( JoinHint(
None, None,
Some(HintInfo(broadcast = true))):: Nil Some(HintInfo(broadcast = true))) ::
) JoinHint(
Some(HintInfo(broadcast = true)),
None) :: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(a, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint.NONE ::
JoinHint(
Some(HintInfo(broadcast = true)),
Some(HintInfo(broadcast = true))) :: Nil
)
verifyJoinHint(
sql("select /*+ broadcast(b, c)*/ * from a, c, b " +
"where a.a1 = b.b1 and b.b1 = c.c1"),
JoinHint(
None,
Some(HintInfo(broadcast = true))) ::
JoinHint(
None,
Some(HintInfo(broadcast = true))) :: Nil
)
verifyJoinHint( verifyJoinHint(
df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
.join(df3, 'b1 === 'c1 && 'a1 < 10), .join(df3, 'b1 === 'c1 && 'a1 < 10),
JoinHint(
Some(HintInfo(broadcast = true)),
None) ::
JoinHint.NONE:: Nil
)
verifyJoinHint(
df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
.join(df3, 'b1 === 'c1 && 'a1 < 10)
.join(df, 'b1 === 'id),
JoinHint.NONE ::
JoinHint( JoinHint(
Some(HintInfo(broadcast = true)), Some(HintInfo(broadcast = true)),
None) :: None) ::
JoinHint.NONE:: Nil JoinHint.NONE :: Nil
) )
verifyJoinHint(
df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast")
.join(df3, 'b1 === 'c1 && 'a1 < 10)
.join(df, 'b1 === 'id),
JoinHint.NONE ::
JoinHint(
Some(HintInfo(broadcast = true)),
None) ::
JoinHint.NONE :: Nil
)
}
} }
} }