[SPARK-22042][SQL] ReorderJoinPredicates can break when child's partitioning is not decided

## What changes were proposed in this pull request?

See the JIRA description for the bug: https://issues.apache.org/jira/browse/SPARK-22042

The fix in this PR: in `EnsureRequirements`, apply `ReorderJoinPredicates` over each operator before doing the rule's core logic. Since the tree is transformed bottom-up, we can be sure that an operator's children are already resolved by the time `ReorderJoinPredicates` runs on it.

Theoretically this guarantees covering all such cases while keeping the code simple. My one small grudge is cosmetic: this PR may look odd, given that we don't (to my knowledge) call rules from other rules. I could have moved all the logic for `ReorderJoinPredicates` into `EnsureRequirements`, but that would make it a bit crowded. I am happy to discuss if there are better options.
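Condensed, the integration point looks like this (lifted from the `EnsureRequirements` diff below; the exchange-pruning case is elided):

```scala
// Because transformUp visits children first, `operator`'s children are
// already planned (and their outputPartitioning decided) by the time
// the join keys are reordered.
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
  case operator: SparkPlan =>
    ensureDistributionAndOrdering(reorderJoinPredicates(operator))
}
```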

## How was this patch tested?

Added a new test case in `BucketedReadSuite`.
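The new test (shown in full in the `BucketedReadSuite` diff below) triggers the bug with a join whose left child is itself a join over a bucketed table, so the outer join's child partitioning is only decided while `EnsureRequirements` plans the inner join. The essential query shape, assuming `spark` is the session used by the test harness:

```scala
// Outer join on top of an inner join: with broadcast joins disabled, the
// outer join's child partitioning is not known until the inner join (ab)
// has been planned.
spark.sql("""
  SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
  FROM (SELECT a.i, a.j, a.k
        FROM bucketed_table a JOIN table1 b ON a.i = b.i) ab
  JOIN table2 c ON ab.i = c.i
""")
```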

Author: Tejas Patil <tejasp@fb.com>

Closes #19257 from tejasapatil/SPARK-22042_ReorderJoinPredicates.
Authored by Tejas Patil on 2017-12-12 23:30:06 -08:00; committed by gatorsmile. Commit 682eb4f2ea (parent 874350905f). 4 changed files with 106 additions and 97 deletions.

`sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala`

```diff
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
-import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils
@@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
-    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
```

`sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala`

```diff
@@ -17,10 +17,14 @@
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
+  SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -248,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
+  /**
+   * When the physical operators are created for JOIN, the ordering of join keys is based on order
+   * in which the join keys appear in the user query. That might not match with the output
+   * partitioning of the join node's children (thus leading to extra sort / shuffle being
+   * introduced). This rule will change the ordering of the join keys to match with the
+   * partitioning of the join nodes' children.
+   */
+  def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+    def reorderJoinKeys(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        leftPartitioning: Partitioning,
+        rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+      def reorder(expectedOrderOfKeys: Seq[Expression],
+          currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+        val leftKeysBuffer = ArrayBuffer[Expression]()
+        val rightKeysBuffer = ArrayBuffer[Expression]()
+
+        expectedOrderOfKeys.foreach(expression => {
+          val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+          leftKeysBuffer.append(leftKeys(index))
+          rightKeysBuffer.append(rightKeys(index))
+        })
+        (leftKeysBuffer, rightKeysBuffer)
+      }
+
+      if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+        leftPartitioning match {
+          case HashPartitioning(leftExpressions, _)
+            if leftExpressions.length == leftKeys.length &&
+              leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+            reorder(leftExpressions, leftKeys)
+
+          case _ => rightPartitioning match {
+            case HashPartitioning(rightExpressions, _)
+              if rightExpressions.length == rightKeys.length &&
+                rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+              reorder(rightExpressions, rightKeys)
+
+            case _ => (leftKeys, rightKeys)
+          }
+        }
+      } else {
+        (leftKeys, rightKeys)
+      }
+    }
+
+    plan.transformUp {
+      case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
+          right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+    }
+  }
+
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator @ ShuffleExchangeExec(partitioning, child, _) =>
       child.children match {
@@ -255,6 +328,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         if (childPartitioning.guarantees(partitioning)) child else operator
       case _ => operator
     }
-    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+    case operator: SparkPlan =>
+      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
   }
 }
```
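For intuition, here is a minimal self-contained sketch of the `reorder` step above; `Expr` is a toy stand-in for Catalyst's `Expression`, not Spark API:

```scala
import scala.collection.mutable.ArrayBuffer

// Toy stand-in for Catalyst's Expression; semanticEquals is plain name
// equality here, whereas Catalyst compares canonicalized expression trees.
case class Expr(name: String) {
  def semanticEquals(other: Expr): Boolean = name == other.name
}

// Mirrors the `reorder` helper: walk the keys in the order expected by the
// child's HashPartitioning and pull the matching (left, right) key pairs
// into that order.
def reorder(
    leftKeys: Seq[Expr],
    rightKeys: Seq[Expr],
    expectedOrderOfKeys: Seq[Expr],
    currentOrderOfKeys: Seq[Expr]): (Seq[Expr], Seq[Expr]) = {
  val leftBuf = ArrayBuffer[Expr]()
  val rightBuf = ArrayBuffer[Expr]()
  expectedOrderOfKeys.foreach { e =>
    val i = currentOrderOfKeys.indexWhere(_.semanticEquals(e))
    leftBuf += leftKeys(i)
    rightBuf += rightKeys(i)
  }
  (leftBuf.toSeq, rightBuf.toSeq)
}

// Join written as "a.j = b.j AND a.k = b.k", but `a` is hash-partitioned
// by (k, j):
val (l, r) = reorder(
  leftKeys = Seq(Expr("a.j"), Expr("a.k")),
  rightKeys = Seq(Expr("b.j"), Expr("b.k")),
  expectedOrderOfKeys = Seq(Expr("a.k"), Expr("a.j")),
  currentOrderOfKeys = Seq(Expr("a.j"), Expr("a.k")))
// l == Seq(Expr("a.k"), Expr("a.j")), r == Seq(Expr("b.k"), Expr("b.j")):
// the key order now matches the child's partitioning, so no extra shuffle
// is needed on that side.
```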

`sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala` (deleted)

@ -1,94 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
* partitioning of the join node's children (thus leading to extra sort / shuffle being
* introduced). This rule will change the ordering of the join keys to match with the
* partitioning of the join nodes' children.
*/
class ReorderJoinPredicates extends Rule[SparkPlan] {
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
def reorder(
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftExpressions, leftKeys)
case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(rightExpressions, rightKeys)
case _ => (leftKeys, rightKeys)
}
}
} else {
(leftKeys, rightKeys)
}
}
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)
case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)
case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
}
}

`sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala`

```diff
@@ -602,6 +602,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }
 
+  test("SPARK-22042 ReorderJoinPredicates can break when child's partitioning is not decided") {
+    withTable("bucketed_table", "table1", "table2") {
+      df.write.format("parquet").saveAsTable("table1")
+      df.write.format("parquet").saveAsTable("table2")
+      df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+        checkAnswer(
+          sql("""
+              |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
+              |FROM (
+              |  SELECT a.i, a.j, a.k
+              |  FROM bucketed_table a
+              |  JOIN table1 b
+              |  ON a.i = b.i
+              |) ab
+              |JOIN table2 c
+              |ON ab.i = c.i
+              |""".stripMargin),
+          sql("""
+              |SELECT a.i, a.j, a.k, c.i, c.j, c.k
+              |FROM bucketed_table a
+              |JOIN table1 b
+              |ON a.i = b.i
+              |JOIN table2 c
+              |ON a.i = c.i
+              |""".stripMargin))
+      }
+    }
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
```
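To run only the new test from the Spark repo root, something like `build/sbt "sql/testOnly *BucketedReadSuite -- -z SPARK-22042"` should work; the suite wildcard and ScalaTest `-z` name filter are the usual sbt conventions, not verified against this branch.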