[SPARK-34729][SQL] Faster execution for broadcast nested loop join (left semi/anti with no condition)

### What changes were proposed in this pull request?

For `BroadcastNestedLoopJoinExec` left semi and left anti joins without a condition, when the left side is broadcast, we currently decide whether every row from the broadcast side has a match by [iterating over the broadcast side once per streamed row](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala#L256-L275). This is unnecessary and very inefficient when there is no condition, because we only need to check whether the stream side is empty or not. This PR adds that optimization, which can greatly speed up execution of the affected queries.
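The idea can be sketched standalone as follows (a minimal sketch with hypothetical names, not the PR's exact code; the real logic lives in `leftExistenceJoin`, shown in the diff below):

```scala
import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

// Minimal sketch of the no-condition fast path (hypothetical helper).
// `exists` is true for left semi and false for left anti. With no join
// condition, every broadcast-side row matches iff the stream side has at
// least one row, so a single isEmpty() check on the stream side replaces
// the per-row scan that builds a bit set of matched broadcast rows.
def noConditionLeftExistenceJoin[T: ClassTag](
    sc: SparkContext,
    streamRdd: RDD[T],
    broadcastRows: Array[T],
    exists: Boolean): RDD[T] = {
  val streamNonEmpty = !streamRdd.isEmpty()
  if (streamNonEmpty == exists) sc.makeRDD(broadcastRows) else sc.emptyRDD[T]
}
```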

In addition, this PR introduces a common method `getMatchedBroadcastRowsBitSet()` shared by several join paths, and refactors `defaultJoin()` to move:
* left semi and left anti join logic to `leftExistenceJoin()`
* existence join logic to `existenceJoin()`

After this, `defaultJoin()` holds logic only for outer joins (left outer, right outer and full outer), which in my opinion is much easier to read.
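After the refactor, the `(joinType, buildSide)` dispatch in `doExecute()` has the following shape (condensed from the diff below; the inner-join case line is assumed surrounding context that the diff hunk does not show):

```scala
// Condensed dispatch after the refactor. LeftSemi, LeftAnti and
// ExistenceJoin now route to the specialized methods for either build side,
// so defaultJoin() is left with only the outer-join cases.
(joinType, buildSide) match {
  case (_: InnerLike, _) =>                       // assumed, not in the hunk
    innerJoin(broadcastedRelation)
  case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
    outerJoin(broadcastedRelation)
  case (LeftSemi, _) =>
    leftExistenceJoin(broadcastedRelation, exists = true)
  case (LeftAnti, _) =>
    leftExistenceJoin(broadcastedRelation, exists = false)
  case (_: ExistenceJoin, _) =>
    existenceJoin(broadcastedRelation)
  case _ => // LeftOuter w/ BuildLeft, RightOuter w/ BuildRight, FullOuter
    defaultJoin(broadcastedRelation)
}
```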

### Why are the changes needed?

They improve the performance of affected queries significantly.
Tested with a simple query by locally modifying `JoinBenchmark.scala`:

```
val N = 20 << 20
val M = 1 << 4
val dim = broadcast(spark.range(M).selectExpr("id as k"))
val df = dim.join(spark.range(N), Seq.empty, "left_semi")
df.noop()
```

This shows a >30x runtime improvement. Note that the stream side here is only `spark.range(N)`; for a complicated query with a non-trivial stream side, the savings would be much greater.

```
Running benchmark: broadcast nested loop left semi join
  Running case: broadcast nested loop left semi join optimization off
  Stopped after 2 iterations, 3163 ms
  Running case: broadcast nested loop left semi join optimization on
  Stopped after 5 iterations, 366 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
broadcast nested loop left semi join:                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------
broadcast nested loop left semi join optimization off           1568           1582          19         13.4          74.8       1.0X
broadcast nested loop left semi join optimization on              46             73          18        456.0           2.2      34.1X
```
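As a sanity check of the semantics the fast path relies on (a spark-shell sketch, assuming a running `SparkSession` named `spark`; the `Seq.empty` join-key form is the same one the benchmark snippet above uses):

```scala
// With no join condition, a left semi join keeps every left row iff the
// right side is non-empty, and a left anti join iff the right side is empty,
// so only an emptiness check on the other side is needed.
val leftDf = spark.range(3)
assert(leftDf.join(spark.range(1), Seq.empty, "left_semi").count() == 3)
assert(leftDf.join(spark.range(1), Seq.empty, "left_anti").count() == 0)
assert(leftDf.join(spark.range(0), Seq.empty, "left_anti").count() == 3)
```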

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit tests in `ExistenceJoinSuite.scala` covering the no-condition left semi and left anti cases.
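The new cases pass `None` as the join condition; for example (copied from the test diff below):

```scala
testExistenceJoin(
  "test no condition with non-empty right side for left semi join",
  LeftSemi,
  left,
  right,
  None, // no join condition: exercises the new fast path
  Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0), Row(3, 3.0),
    Row(null, null), Row(null, 5.0), Row(6, null)))
```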

Closes #31821 from c21/nested-join.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
Commit a0f3b72e1c (parent e757091820), committed 2021-03-14 23:51:36 -07:00
2 changed files with 155 additions and 120 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

```diff
@@ -193,50 +193,86 @@ case class BroadcastNestedLoopJoinExec(
   }
 
   /**
-   * The implementation for these joins:
-   *
-   *   LeftSemi with BuildRight
-   *   LeftAnti with BuildRight
+   * The implementation for LeftSemi and LeftAnti joins.
    */
   private def leftExistenceJoin(
       relation: Broadcast[Array[InternalRow]],
       exists: Boolean): RDD[InternalRow] = {
-    assert(buildSide == BuildRight)
-    streamed.execute().mapPartitionsInternal { streamedIter =>
-      val buildRows = relation.value
-      val joinedRow = new JoinedRow
-
-      if (condition.isDefined) {
-        streamedIter.filter(l =>
-          buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
-        )
-      } else if (buildRows.nonEmpty == exists) {
-        streamedIter
-      } else {
-        Iterator.empty
-      }
+    buildSide match {
+      case BuildRight =>
+        streamed.execute().mapPartitionsInternal { streamedIter =>
+          val buildRows = relation.value
+          val joinedRow = new JoinedRow
+
+          if (condition.isDefined) {
+            streamedIter.filter(l =>
+              buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
+            )
+          } else if (buildRows.nonEmpty == exists) {
+            streamedIter
+          } else {
+            Iterator.empty
+          }
+        }
+      case BuildLeft if condition.isEmpty =>
+        // If condition is empty, do not need to read rows from streamed side at all.
+        // Only need to know whether streamed side is empty or not.
+        val streamExists = !streamed.execute().isEmpty()
+        if (streamExists == exists) {
+          sparkContext.makeRDD(relation.value)
+        } else {
+          sparkContext.emptyRDD
+        }
+      case _ => // BuildLeft
+        val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamed.execute(), relation)
+        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+        var i = 0
+        val buildRows = relation.value
+        while (i < buildRows.length) {
+          if (matchedBroadcastRows.get(i) == exists) {
+            buf += buildRows(i).copy()
+          }
+          i += 1
+        }
+        sparkContext.makeRDD(buf)
     }
   }
 
   /**
    * The implementation for ExistenceJoin
    */
   private def existenceJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
-    assert(buildSide == BuildRight)
-    streamed.execute().mapPartitionsInternal { streamedIter =>
-      val buildRows = relation.value
-      val joinedRow = new JoinedRow
-
-      if (condition.isDefined) {
-        val resultRow = new GenericInternalRow(Array[Any](null))
-        streamedIter.map { row =>
-          val result = buildRows.exists(r => boundCondition(joinedRow(row, r)))
-          resultRow.setBoolean(0, result)
-          joinedRow(row, resultRow)
-        }
-      } else {
-        val resultRow = new GenericInternalRow(Array[Any](buildRows.nonEmpty))
-        streamedIter.map { row =>
-          joinedRow(row, resultRow)
-        }
-      }
+    buildSide match {
+      case BuildRight =>
+        streamed.execute().mapPartitionsInternal { streamedIter =>
+          val buildRows = relation.value
+          val joinedRow = new JoinedRow
+
+          if (condition.isDefined) {
+            val resultRow = new GenericInternalRow(Array[Any](null))
+            streamedIter.map { row =>
+              val result = buildRows.exists(r => boundCondition(joinedRow(row, r)))
+              resultRow.setBoolean(0, result)
+              joinedRow(row, resultRow)
+            }
+          } else {
+            val resultRow = new GenericInternalRow(Array[Any](buildRows.nonEmpty))
+            streamedIter.map { row =>
+              joinedRow(row, resultRow)
+            }
+          }
+        }
+      case _ => // BuildLeft
+        val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamed.execute(), relation)
+        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+        var i = 0
+        val buildRows = relation.value
+        while (i < buildRows.length) {
+          val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i)))
+          buf += new JoinedRow(buildRows(i).copy(), result)
+          i += 1
+        }
+        sparkContext.makeRDD(buf)
     }
   }
@@ -246,71 +282,10 @@ case class BroadcastNestedLoopJoinExec(
    * LeftOuter with BuildLeft
    * RightOuter with BuildRight
    * FullOuter
-   * LeftSemi with BuildLeft
-   * LeftAnti with BuildLeft
-   * ExistenceJoin with BuildLeft
    */
   private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
     val streamRdd = streamed.execute()
-    val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
-      val buildRows = relation.value
-      val matched = new BitSet(buildRows.length)
-      val joinedRow = new JoinedRow
-      streamedIter.foreach { streamedRow =>
-        var i = 0
-        while (i < buildRows.length) {
-          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
-            matched.set(i)
-          }
-          i += 1
-        }
-      }
-      Seq(matched).toIterator
-    }
-
-    val matchedBroadcastRows = matchedBuildRows.fold(
-      new BitSet(relation.value.length)
-    )(_ | _)
-
-    joinType match {
-      case LeftSemi =>
-        assert(buildSide == BuildLeft)
-        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
-        var i = 0
-        val rel = relation.value
-        while (i < rel.length) {
-          if (matchedBroadcastRows.get(i)) {
-            buf += rel(i).copy()
-          }
-          i += 1
-        }
-        return sparkContext.makeRDD(buf)
-      case _: ExistenceJoin =>
-        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
-        var i = 0
-        val rel = relation.value
-        while (i < rel.length) {
-          val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i)))
-          buf += new JoinedRow(rel(i).copy(), result)
-          i += 1
-        }
-        return sparkContext.makeRDD(buf)
-      case LeftAnti =>
-        val notMatched: CompactBuffer[InternalRow] = new CompactBuffer()
-        var i = 0
-        val rel = relation.value
-        while (i < rel.length) {
-          if (!matchedBroadcastRows.get(i)) {
-            notMatched += rel(i).copy()
-          }
-          i += 1
-        }
-        return sparkContext.makeRDD(notMatched)
-      case _ =>
-    }
-
+    val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamRdd, relation)
     val notMatchedBroadcastRows: Seq[InternalRow] = {
       val nulls = new GenericInternalRow(streamed.output.size)
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
@@ -358,6 +333,34 @@ case class BroadcastNestedLoopJoinExec(
     )
   }
 
+  /**
+   * Get matched rows from broadcast side as a [[BitSet]].
+   * Create a local [[BitSet]] for broadcast side on each RDD partition,
+   * and merge all [[BitSet]]s together.
+   */
+  private def getMatchedBroadcastRowsBitSet(
+      streamRdd: RDD[InternalRow],
+      relation: Broadcast[Array[InternalRow]]): BitSet = {
+    val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val matched = new BitSet(buildRows.length)
+      val joinedRow = new JoinedRow
+      streamedIter.foreach { streamedRow =>
+        var i = 0
+        while (i < buildRows.length) {
+          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+            matched.set(i)
+          }
+          i += 1
+        }
+      }
+      Seq(matched).toIterator
+    }
+
+    matchedBuildRows.fold(new BitSet(relation.value.length))(_ | _)
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
@@ -366,20 +369,17 @@ case class BroadcastNestedLoopJoinExec(
         innerJoin(broadcastedRelation)
       case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
         outerJoin(broadcastedRelation)
-      case (LeftSemi, BuildRight) =>
+      case (LeftSemi, _) =>
        leftExistenceJoin(broadcastedRelation, exists = true)
-      case (LeftAnti, BuildRight) =>
+      case (LeftAnti, _) =>
         leftExistenceJoin(broadcastedRelation, exists = false)
-      case (_: ExistenceJoin, BuildRight) =>
+      case (_: ExistenceJoin, _) =>
         existenceJoin(broadcastedRelation)
       case _ =>
         /**
          * LeftOuter with BuildLeft
          * RightOuter with BuildRight
          * FullOuter
-         * LeftSemi with BuildLeft
-         * LeftAnti with BuildLeft
-         * ExistenceJoin with BuildLeft
          */
         defaultJoin(broadcastedRelation)
     }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala

```diff
@@ -82,12 +82,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
       joinType: JoinType,
       leftRows: => DataFrame,
       rightRows: => DataFrame,
-      condition: => Expression,
+      condition: => Option[Expression],
       expectedAnswer: Seq[Row]): Unit = {
 
     def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
       val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
-        Inner, Some(condition), JoinHint.NONE)
+        Inner, condition, JoinHint.NONE)
       ExtractEquiJoinKeys.unapply(join)
     }
@@ -163,13 +163,13 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements.apply(
-          BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition))),
+          BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, condition)),
         expectedAnswer,
         sortAnswers = true)
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements.apply(
           createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
-            left, right, BuildLeft, leftSemiPlus, Some(condition)))),
+            left, right, BuildLeft, leftSemiPlus, condition))),
         expectedAnswer,
         sortAnswers = true)
     }
@@ -179,25 +179,42 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements.apply(
-          BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition))),
+          BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, condition)),
         expectedAnswer,
         sortAnswers = true)
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements.apply(
           createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
-            left, right, BuildRight, leftSemiPlus, Some(condition)))),
+            left, right, BuildRight, leftSemiPlus, condition))),
         expectedAnswer,
         sortAnswers = true)
       }
     }
   }
 
+  testExistenceJoin(
+    "test no condition with non-empty right side for left semi join",
+    LeftSemi,
+    left,
+    right,
+    None,
+    Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(null, null),
+      Row(null, 5.0), Row(6, null)))
+
+  testExistenceJoin(
+    "test no condition with empty right side for left semi join",
+    LeftSemi,
+    left,
+    spark.emptyDataFrame,
+    None,
+    Seq.empty)
+
   testExistenceJoin(
     "test single condition (equal) for left semi join",
     LeftSemi,
     left,
     right,
-    singleConditionEQ,
+    Some(singleConditionEQ),
     Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))
 
   testExistenceJoin(
@@ -205,7 +222,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     LeftSemi,
     left,
     right.select(right.col("c")).distinct(), /* Trigger BHJs and SHJs unique key code path! */
-    singleConditionEQ,
+    Some(singleConditionEQ),
     Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))
 
   testExistenceJoin(
@@ -213,7 +230,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     LeftSemi,
     left,
     right,
-    composedConditionEQ,
+    Some(composedConditionEQ),
     Seq(Row(2, 1.0), Row(2, 1.0)))
 
   testExistenceJoin(
@@ -221,47 +238,65 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
     LeftSemi,
     left,
     right,
-    composedConditionNEQ,
+    Some(composedConditionNEQ),
     Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))
 
   testExistenceJoin(
-    "test single condition (equal) for left Anti join",
+    "test no condition with non-empty right side for left anti join",
     LeftAnti,
     left,
     right,
-    singleConditionEQ,
+    None,
+    Seq.empty)
+
+  testExistenceJoin(
+    "test no condition with empty right side for left anti join",
+    LeftAnti,
+    left,
+    spark.emptyDataFrame,
+    None,
+    Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(null, null),
+      Row(null, 5.0), Row(6, null)))
+
+  testExistenceJoin(
+    "test single condition (equal) for left anti join",
+    LeftAnti,
+    left,
+    right,
+    Some(singleConditionEQ),
     Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))
 
   testExistenceJoin(
-    "test single unique condition (equal) for left Anti join",
+    "test single unique condition (equal) for left anti join",
     LeftAnti,
     left,
     right.select(right.col("c")).distinct(), /* Trigger BHJs and SHJs unique key code path! */
-    singleConditionEQ,
+    Some(singleConditionEQ),
     Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))
 
   testExistenceJoin(
-    "test composed condition (equal & non-equal) test for anti join",
+    "test composed condition (equal & non-equal) test for left anti join",
     LeftAnti,
     left,
     right,
-    composedConditionEQ,
+    Some(composedConditionEQ),
     Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
 
   testExistenceJoin(
-    "test composed condition (both non-equal) for anti join",
+    "test composed condition (both non-equal) for left anti join",
     LeftAnti,
     left,
     right,
-    composedConditionNEQ,
+    Some(composedConditionNEQ),
     Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
 
   testExistenceJoin(
-    "test composed unique condition (both non-equal) for anti join",
+    "test composed unique condition (both non-equal) for left anti join",
     LeftAnti,
     left,
     rightUniqueKey,
-    (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr,
+    Some((left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d"))
+      .expr),
     Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null)))
 }
```