[SPARK-27485] EnsureRequirements.reorder should handle duplicate expressions gracefully

## What changes were proposed in this pull request?
When reordering joins, EnsureRequirements only checks whether all of the join keys are present in the partitioning expression sequence. This is problematic when the join keys and the partitioning expressions both contain duplicates, but not the same number of duplicates for each expression, e.g. `Seq(a, a, b)` vs `Seq(a, b, b)`. This fails with an index lookup failure in the `reorder` function.
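
The failure is easy to reproduce outside Spark. The sketch below is a hypothetical reduction of the pre-patch lookup, with plain strings standing in for catalyst expressions (so `==` replaces `semanticEquals`):

```scala
import scala.collection.mutable

// Reduced model of the pre-patch behaviour; strings stand in for catalyst
// expressions, so plain equality replaces semanticEquals.
object DuplicateKeyBug extends App {
  val joinKeys = Seq("a", "a", "b")        // keys from the join condition
  val partitionExprs = Seq("a", "b", "b")  // expressions of the HashPartitioning

  // The old guard only checks containment, so it passes even though the
  // multisets differ (two a's on one side, two b's on the other).
  assert(joinKeys.forall(k => partitionExprs.contains(k)))

  // The old reorder then picks, for every partitioning expression, a not yet
  // used join-key index. The second "b" finds none left.
  val picked = mutable.Set[Int]()
  partitionExprs.foreach { e =>
    val index = joinKeys.zipWithIndex
      .find { case (k, i) => k == e && !picked.contains(i) }
      .map(_._2)
      .get // throws NoSuchElementException on the second "b"
    picked += index
  }
}
```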

This PR fixes this by removing the equality-checking logic from the `reorderJoinKeys` function and by performing the multiset equality check in the `reorder` function while building the reordered key sequences.
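
A minimal sketch of that multiset strategy, under the same string-for-expression simplification (`reorderIndices` is a hypothetical helper, not the patch itself): bucket the positions of every current key, consume one position per expected key, and give up as soon as a key runs out:

```scala
import scala.collection.mutable

// Returns the key indices in the expected order, or None when the two
// sides are not multiset-equal (mirrors the patched reorder's fallback).
def reorderIndices(expected: Seq[String], current: Seq[String]): Option[Seq[Int]] = {
  if (expected.size != current.size) return None
  // Map every key to the ordered set of positions it occupies.
  val keyToIndices = mutable.Map.empty[String, mutable.BitSet]
  current.zipWithIndex.foreach { case (k, i) =>
    keyToIndices.getOrElseUpdate(k, mutable.BitSet.empty).add(i)
  }
  val result = Seq.newBuilder[Int]
  expected.foreach { k =>
    keyToIndices.get(k) match {
      case Some(indices) if indices.nonEmpty =>
        val i = indices.firstKey   // take the first unused position
        indices.remove(i)
        result += i
      case _ => return None        // multisets differ: keep the original order
    }
  }
  Some(result.result())
}

// Seq(a, a, b) vs Seq(a, b, b): the second "b" has no position left.
assert(reorderIndices(Seq("a", "b", "b"), Seq("a", "a", "b")).isEmpty)
// Equal multisets reorder cleanly.
assert(reorderIndices(Seq("b", "a"), Seq("a", "b")).contains(Seq(1, 0)))
```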

## How was this patch tested?
Added a unit test to `PlannerSuite` and an integration test to `JoinSuite`.

Closes #25167 from hvanhovell/SPARK-27485.

Authored-by: herman <herman@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 9a7f01d944
commit 421d9d56ef
3 changed files with 86 additions and 32 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   private def reorder(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
+      leftKeys: IndexedSeq[Expression],
+      rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
-    val pickedIndexes = mutable.Set[Int]()
-    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
-
-    expectedOrderOfKeys.foreach(expression => {
-      val index = keysAndIndexes.find { case (e, idx) =>
-        // As we may have the same key used many times, we need to filter out its occurrence we
-        // have already used.
-        e.semanticEquals(expression) && !pickedIndexes.contains(idx)
-      }.map(_._2).get
-      pickedIndexes += index
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
-    })
+    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
+      return (leftKeys, rightKeys)
+    }
+
+    // Build a lookup between an expression and the positions it holds in the current key seq.
+    val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
+    currentOrderOfKeys.zipWithIndex.foreach {
+      case (key, index) =>
+        keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
+    }
+
+    // Reorder the keys.
+    val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
+    val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
+    val iterator = expectedOrderOfKeys.iterator
+    while (iterator.hasNext) {
+      // Lookup the current index of this key.
+      keyToIndexMap.get(iterator.next().canonicalized) match {
+        case Some(indices) if indices.nonEmpty =>
+          // Take the first available index from the map.
+          val index = indices.firstKey
+          indices.remove(index)
+
+          // Add the keys for that index to the reordered keys.
+          leftKeysBuffer += leftKeys(index)
+          rightKeysBuffer += rightKeys(index)
+        case _ =>
+          // The expression cannot be found, or we have exhausted all indices for that expression.
+          return (leftKeys, rightKeys)
+      }
+    }
     (leftKeysBuffer, rightKeysBuffer)
   }
@@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
+      (leftPartitioning, rightPartitioning) match {
+        case (HashPartitioning(leftExpressions, _), _) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+        case (_, HashPartitioning(rightExpressions, _)) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+        case _ =>
+          (leftKeys, rightKeys)
       }
     } else {
       (leftKeys, rightKeys)
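
One consequence of the rewritten match above: the guards that used to live in `reorderJoinKeys` (length equality plus per-key `semanticEquals` containment) are now enforced inside `reorder`, which simply hands the keys back untouched when the multisets differ. A hedged contract check, reusing the hypothetical `reorderIndices` sketch from the description:

```scala
// Sketch only: when the partitioning expressions are not a multiset match
// for the join keys, the keys come back in their original order and
// EnsureRequirements falls back to inserting a shuffle instead of crashing.
val leftKeys  = Vector("a", "b", "b")   // stand-ins for the join-key expressions
val rightKeys = Vector("x", "y", "z")
val partitioningExprs = Seq("a", "b", "a")

val (newLeft, newRight) = reorderIndices(partitioningExprs, leftKeys) match {
  case Some(indices) => (indices.map(leftKeys), indices.map(rightKeys))
  case None          => (leftKeys, rightKeys)  // graceful fallback
}
assert(newLeft == leftKeys && newRight == rightKeys)
```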

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

@@ -898,6 +898,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements should not fail join with duplicate keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val tbl_a = spark.range(40)
+        .select($"id" as "x", $"id" % 10 as "y")
+        .repartition(2, $"x", $"y", $"x")
+        .as("tbl_a")
+      val tbl_b = spark.range(20)
+        .select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2")
+        .as("tbl_b")
+      val res = tbl_a
+        .join(tbl_b,
+          $"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && $"tbl_a.y" === $"tbl_b.y2")
+        .select($"tbl_a.x")
+      checkAnswer(res, Row(0L) :: Row(1L) :: Nil)
+    }
+  }
+
   test("SPARK-26352: join reordering should not change the order of columns") {
     withTable("tab1", "tab2", "tab3") {
       spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

@@ -696,6 +696,32 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements.reorder should handle duplicate expressions") {
+    val plan1 = DummySparkPlan(
+      outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan()
+    val smjExec = SortMergeJoinExec(
+      leftKeys = exprA :: exprB :: exprB :: Nil,
+      rightKeys = exprA :: exprC :: exprC :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = plan1,
+      right = plan2)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _), _),
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _), _)) =>
+        assert(leftKeys === smjExec.leftKeys)
+        assert(rightKeys === smjExec.rightKeys)
+        assert(leftKeys === leftPartitioningExpressions)
+        assert(rightKeys === rightPartitioningExpressions)
+      case _ => fail(outputPlan.toString)
+    }
+  }
+
   test("SPARK-24500: create union with stream of children") {
     val df = Union(Stream(
       Range(1, 1, 1, 1),