[SPARK-34729][SQL] Faster execution for broadcast nested loop join (left semi/anti with no condition)
### What changes were proposed in this pull request?

This PR optimizes `BroadcastNestedLoopJoinExec` for left semi and left anti joins without a join condition, for the case where the left side is broadcast. Currently we check whether every row from the broadcast side has a match by [iterating over the broadcast side many times](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala#L256-L275). This is unnecessary and very inefficient when there is no condition, because we only need to check whether the stream side is empty or not. This PR adds that optimization, which can speed up the affected queries substantially.

In addition, this PR introduces a common method `getMatchedBroadcastRowsBitSet()` shared by several code paths, and refactors `defaultJoin()` to move:

* left semi and left anti join logic into `leftExistenceJoin()`
* existence join logic into `existenceJoin()`

After this, `defaultJoin()` only holds logic for outer joins (left outer, right outer and full outer), which in my opinion is much easier to read.

### Why are the changes needed?

Improve the affected query performance a lot. Tested with a simple query by modifying `JoinBenchmark.scala` locally:

```
val N = 20 << 20
val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k"))
val df = dim.join(spark.range(N), Seq.empty, "left_semi")
df.noop()
```

This shows a >30x run-time improvement. Note the stream side is only `spark.range(N)`; for a complicated query with a non-trivial stream side, the saving would be much larger.
```
Running benchmark: broadcast nested loop left semi join
  Running case: broadcast nested loop left semi join optimization off
  Stopped after 2 iterations, 3163 ms
  Running case: broadcast nested loop left semi join optimization on
  Stopped after 5 iterations, 366 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.7
Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
broadcast nested loop left semi join:                  Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------------------
broadcast nested loop left semi join optimization off           1568           1582          19        13.4          74.8       1.0X
broadcast nested loop left semi join optimization on              46             73          18       456.0           2.2      34.1X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit test in `ExistenceJoinSuite.scala`.

Closes #31821 from c21/nested-join.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
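The semantics the optimization relies on can be sketched in a few lines of plain Python (a model of a left semi/anti join with no condition, not Spark code; the function name is made up for illustration): with no join condition every pair of rows "matches", so the join result depends only on whether the other side is empty.

```python
def left_existence_join(left, right, exists):
    """exists=True models LEFT SEMI, exists=False models LEFT ANTI."""
    # With no condition, a left row has a match iff the right side is non-empty,
    # so we never need to iterate over the right (stream) side's rows.
    return list(left) if (len(right) > 0) == exists else []

# Left semi: keep left rows only when the right side has rows.
assert left_existence_join([1, 2, 3], ["x"], exists=True) == [1, 2, 3]
assert left_existence_join([1, 2, 3], [], exists=True) == []
# Left anti: keep left rows only when the right side is empty.
assert left_existence_join([1, 2, 3], ["x"], exists=False) == []
assert left_existence_join([1, 2, 3], [], exists=False) == [1, 2, 3]
```

This is why the new code path only calls `isEmpty()` on the stream side instead of evaluating row pairs.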
commit a0f3b72e1c (parent e757091820)
@@ -193,50 +193,86 @@ case class BroadcastNestedLoopJoinExec(

```scala
  }

  /**
   * The implementation for LeftSemi and LeftAnti joins.
   */
  private def leftExistenceJoin(
      relation: Broadcast[Array[InternalRow]],
      exists: Boolean): RDD[InternalRow] = {
    buildSide match {
      case BuildRight =>
        streamed.execute().mapPartitionsInternal { streamedIter =>
          val buildRows = relation.value
          val joinedRow = new JoinedRow

          if (condition.isDefined) {
            streamedIter.filter(l =>
              buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
            )
          } else if (buildRows.nonEmpty == exists) {
            streamedIter
          } else {
            Iterator.empty
          }
        }
      case BuildLeft if condition.isEmpty =>
        // If condition is empty, do not need to read rows from streamed side at all.
        // Only need to know whether streamed side is empty or not.
        val streamExists = !streamed.execute().isEmpty()
        if (streamExists == exists) {
          sparkContext.makeRDD(relation.value)
        } else {
          sparkContext.emptyRDD
        }
      case _ => // BuildLeft
        val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamed.execute(), relation)
        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
        var i = 0
        val buildRows = relation.value
        while (i < buildRows.length) {
          if (matchedBroadcastRows.get(i) == exists) {
            buf += buildRows(i).copy()
          }
          i += 1
        }
        sparkContext.makeRDD(buf)
    }
  }

  /**
   * The implementation for ExistenceJoin
   */
  private def existenceJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
    buildSide match {
      case BuildRight =>
        streamed.execute().mapPartitionsInternal { streamedIter =>
          val buildRows = relation.value
          val joinedRow = new JoinedRow

          if (condition.isDefined) {
            val resultRow = new GenericInternalRow(Array[Any](null))
            streamedIter.map { row =>
              val result = buildRows.exists(r => boundCondition(joinedRow(row, r)))
              resultRow.setBoolean(0, result)
              joinedRow(row, resultRow)
            }
          } else {
            val resultRow = new GenericInternalRow(Array[Any](buildRows.nonEmpty))
            streamedIter.map { row =>
              joinedRow(row, resultRow)
            }
          }
        }
      case _ => // BuildLeft
        val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamed.execute(), relation)
        val buf: CompactBuffer[InternalRow] = new CompactBuffer()
        var i = 0
        val buildRows = relation.value
        while (i < buildRows.length) {
          val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i)))
          buf += new JoinedRow(buildRows(i).copy(), result)
          i += 1
        }
        sparkContext.makeRDD(buf)
    }
  }
```
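The `existenceJoin` path above emits every streamed row exactly once, extended with a boolean "exists" column. A rough Python model of that behavior (illustrative only, not Spark code; `existence_join` and its tuple output are made up for the sketch):

```python
def existence_join(stream_rows, build_rows, condition=None):
    """Model of ExistenceJoin with BuildRight: (row, exists_flag) per streamed row."""
    if condition is not None:
        return [(row, any(condition(row, b) for b in build_rows))
                for row in stream_rows]
    # No condition: the flag is the same for every row -- build side non-empty.
    flag = len(build_rows) > 0
    return [(row, flag) for row in stream_rows]

pairs = existence_join([1, 2, 3], [2, 3], condition=lambda s, b: s == b)
assert pairs == [(1, False), (2, True), (3, True)]
# With no condition and an empty build side, every flag is False.
assert existence_join([1, 2], []) == [(1, False), (2, False)]
```

This mirrors why the no-condition branch can reuse a single `resultRow` holding `buildRows.nonEmpty` for all streamed rows.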
@@ -246,71 +282,10 @@ case class BroadcastNestedLoopJoinExec(

```scala
   * LeftOuter with BuildLeft
   * RightOuter with BuildRight
   * FullOuter
   */
  private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
    val streamRdd = streamed.execute()

    val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamRdd, relation)
    val notMatchedBroadcastRows: Seq[InternalRow] = {
      val nulls = new GenericInternalRow(streamed.output.size)
      val buf: CompactBuffer[InternalRow] = new CompactBuffer()
```
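After the refactor, `defaultJoin` only handles outer joins: the matched bitset selects which broadcast rows still need to be emitted padded with nulls on the streamed side. A small Python model of that padding step (illustrative only; the function name and tuple row layout are made up for the sketch):

```python
def not_matched_padded(build_rows, matched_bits, stream_width):
    """Emit each unmatched build row joined with a null row of the streamed schema."""
    nulls = (None,) * stream_width  # models GenericInternalRow(streamed.output.size)
    return [(nulls, b)
            for i, b in enumerate(build_rows)
            if not matched_bits[i]]

rows = not_matched_padded(["r0", "r1", "r2"], [True, False, True], stream_width=2)
assert rows == [((None, None), "r1")]
```

Only the rows whose bit is unset get null-padded output, which is exactly what the `notMatchedBroadcastRows` buffer collects.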
@@ -358,6 +333,34 @@ case class BroadcastNestedLoopJoinExec(

```scala
    )
  }

  /**
   * Get matched rows from broadcast side as a [[BitSet]].
   * Create a local [[BitSet]] for broadcast side on each RDD partition,
   * and merge all [[BitSet]]s together.
   */
  private def getMatchedBroadcastRowsBitSet(
      streamRdd: RDD[InternalRow],
      relation: Broadcast[Array[InternalRow]]): BitSet = {
    val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
      val buildRows = relation.value
      val matched = new BitSet(buildRows.length)
      val joinedRow = new JoinedRow

      streamedIter.foreach { streamedRow =>
        var i = 0
        while (i < buildRows.length) {
          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
            matched.set(i)
          }
          i += 1
        }
      }
      Seq(matched).toIterator
    }

    matchedBuildRows.fold(new BitSet(relation.value.length))(_ | _)
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
```
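The shared `getMatchedBroadcastRowsBitSet` helper builds one local bitset per partition and merges them with bitwise OR. A compact Python model (illustrative only, not Spark code; partitions are modeled as plain lists and the bitset as a Python int):

```python
def matched_bits(partitions, build_rows, condition):
    """Return, per build-row index, whether any streamed row in any partition matched."""
    merged = 0  # a Python int serves as an arbitrary-width bitset
    for part in partitions:
        local = 0  # models the per-partition BitSet(buildRows.length)
        for s in part:
            for i, b in enumerate(build_rows):
                if condition(s, b):
                    local |= 1 << i
        merged |= local  # models matchedBuildRows.fold(new BitSet(...))(_ | _)
    return [bool(merged >> i & 1) for i in range(len(build_rows))]

# Two partitions of the stream side; build rows [1, 2, 3].
bits = matched_bits([[1], [3]], [1, 2, 3], lambda s, b: s == b)
assert bits == [True, False, True]
```

The OR-merge is what makes the per-partition scans independent, so the match information can be computed in one distributed pass instead of re-iterating the broadcast side.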
@@ -366,20 +369,17 @@ case class BroadcastNestedLoopJoinExec(

```scala
        innerJoin(broadcastedRelation)
      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
        outerJoin(broadcastedRelation)
      case (LeftSemi, _) =>
        leftExistenceJoin(broadcastedRelation, exists = true)
      case (LeftAnti, _) =>
        leftExistenceJoin(broadcastedRelation, exists = false)
      case (_: ExistenceJoin, _) =>
        existenceJoin(broadcastedRelation)
      case _ =>
        /**
         * LeftOuter with BuildLeft
         * RightOuter with BuildRight
         * FullOuter
         */
        defaultJoin(broadcastedRelation)
    }
```
@@ -82,12 +82,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
      joinType: JoinType,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Option[Expression],
      expectedAnswer: Seq[Row]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan,
        Inner, condition, JoinHint.NONE)
      ExtractEquiJoinKeys.unapply(join)
    }
```
@@ -163,13 +163,13 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
      checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
        EnsureRequirements.apply(
          BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, condition)),
        expectedAnswer,
        sortAnswers = true)
      checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
        EnsureRequirements.apply(
          createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
            left, right, BuildLeft, leftSemiPlus, condition))),
        expectedAnswer,
        sortAnswers = true)
    }
```
@@ -179,25 +179,42 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
      checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
        EnsureRequirements.apply(
          BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, condition)),
        expectedAnswer,
        sortAnswers = true)
      checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
        EnsureRequirements.apply(
          createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec(
            left, right, BuildRight, leftSemiPlus, condition))),
        expectedAnswer,
        sortAnswers = true)
    }
  }

  testExistenceJoin(
    "test no condition with non-empty right side for left semi join",
    LeftSemi,
    left,
    right,
    None,
    Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(null, null),
      Row(null, 5.0), Row(6, null)))

  testExistenceJoin(
    "test no condition with empty right side for left semi join",
    LeftSemi,
    left,
    spark.emptyDataFrame,
    None,
    Seq.empty)

  testExistenceJoin(
    "test single condition (equal) for left semi join",
    LeftSemi,
    left,
    right,
    Some(singleConditionEQ),
    Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))

  testExistenceJoin(
```
@@ -205,7 +222,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
    LeftSemi,
    left,
    right.select(right.col("c")).distinct(), /* Trigger BHJs and SHJs unique key code path! */
    Some(singleConditionEQ),
    Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))

  testExistenceJoin(
```
@@ -213,7 +230,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
    LeftSemi,
    left,
    right,
    Some(composedConditionEQ),
    Seq(Row(2, 1.0), Row(2, 1.0)))

  testExistenceJoin(
```
@@ -221,47 +238,65 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {

```scala
    LeftSemi,
    left,
    right,
    Some(composedConditionNEQ),
    Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))

  testExistenceJoin(
    "test no condition with non-empty right side for left anti join",
    LeftAnti,
    left,
    right,
    None,
    Seq.empty)

  testExistenceJoin(
    "test no condition with empty right side for left anti join",
    LeftAnti,
    left,
    spark.emptyDataFrame,
    None,
    Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(null, null),
      Row(null, 5.0), Row(6, null)))

  testExistenceJoin(
    "test single condition (equal) for left anti join",
    LeftAnti,
    left,
    right,
    Some(singleConditionEQ),
    Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))

  testExistenceJoin(
    "test single unique condition (equal) for left anti join",
    LeftAnti,
    left,
    right.select(right.col("c")).distinct(), /* Trigger BHJs and SHJs unique key code path! */
    Some(singleConditionEQ),
    Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))

  testExistenceJoin(
    "test composed condition (equal & non-equal) test for left anti join",
    LeftAnti,
    left,
    right,
    Some(composedConditionEQ),
    Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))

  testExistenceJoin(
    "test composed condition (both non-equal) for left anti join",
    LeftAnti,
    left,
    right,
    Some(composedConditionNEQ),
    Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))

  testExistenceJoin(
    "test composed unique condition (both non-equal) for left anti join",
    LeftAnti,
    left,
    rightUniqueKey,
    Some((left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d"))
      .expr),
    Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null)))
}
```