[SPARK-32649][SQL] Optimize BHJ/SHJ inner/semi join with empty hashed relation
### What changes were proposed in this pull request? For broadcast hash join and shuffled hash join, whenever the build side hashed relation turns out to be empty. We don't need to execute stream side plan at all, and can return an empty iterator (for inner join and left semi join), because we know for sure that none of stream side rows can be outputted as there's no match. ### Why are the changes needed? A very minor optimization for rare use case, but in case build side turns out to be empty, we can leverage it to short-cut stream side to save CPU and IO. Example broadcast hash join query similar to `JoinBenchmark` with empty hashed relation: ``` def broadcastHashJoinLongKey(): Unit = { val N = 20 << 20 val M = 1 << 16 val dim = broadcast(spark.range(0).selectExpr("id as k", "cast(id as string) as v")) codegenBenchmark("Join w long", N) { val df = spark.range(N).join(dim, (col("id") % M) === col("k")) assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) df.noop() } } ``` Comparing wall clock time for enabling and disabling this PR (for non-codegen code path). Seeing like 8x improvement. ``` Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.15.4 Intel(R) Core(TM) i9-9980HK CPU 2.40GHz Join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ Join PR disabled 637 646 12 32.9 30.4 1.0X Join PR enabled 77 78 2 271.8 3.7 8.3X ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit test in `JoinSuite`. Closes #29484 from c21/empty-relation. Authored-by: Cheng Su <chengsu@fb.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
11c6a23c13
commit
08b951b1cb
|
@ -80,7 +80,7 @@ case class AdaptiveSparkPlanExec(
|
|||
// TODO add more optimization rules
|
||||
override protected def batches: Seq[Batch] = Seq(
|
||||
Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
|
||||
Batch("Eliminate Null Aware Anti Join", Once, EliminateNullAwareAntiJoin)
|
||||
Batch("Eliminate Join to Empty Relation", Once, EliminateJoinToEmptyRelation)
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.adaptive
|
||||
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
|
||||
import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys}
|
||||
|
||||
/**
|
||||
* This optimization rule detects and converts a Join to an empty [[LocalRelation]]:
|
||||
* 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]]
|
||||
* is [[HashedRelationWithAllNullKeys]].
|
||||
*
|
||||
* 2. Join is inner or left semi join, and broadcasted [[HashedRelation]]
|
||||
* is [[EmptyHashedRelation]].
|
||||
* This applies to all Joins (sort merge join, shuffled hash join, and broadcast hash join),
|
||||
* because sort merge join and shuffled hash join will be changed to broadcast hash join with AQE
|
||||
* at the first place.
|
||||
*/
|
||||
object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] {
|
||||
|
||||
private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match {
|
||||
case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
|
||||
&& stage.broadcast.relationFuture.get().value == relation => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
|
||||
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _)
|
||||
if canEliminate(j.right, HashedRelationWithAllNullKeys) =>
|
||||
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
|
||||
|
||||
case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) ||
|
||||
canEliminate(j.right, EmptyHashedRelation) =>
|
||||
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
|
||||
|
||||
case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) =>
|
||||
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
|
||||
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.adaptive
|
||||
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
|
||||
|
||||
/**
|
||||
* This optimization rule detects and convert a NAAJ to an Empty LocalRelation
|
||||
* when buildSide is HashedRelationWithAllNullKeys.
|
||||
*/
|
||||
object EliminateNullAwareAntiJoin extends Rule[LogicalPlan] {
|
||||
|
||||
private def canEliminate(plan: LogicalPlan): Boolean = plan match {
|
||||
case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
|
||||
&& stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
|
||||
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if canEliminate(j.right) =>
|
||||
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
|
||||
}
|
||||
}
|
|
@ -155,7 +155,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
val joinRow = new JoinedRow
|
||||
val joinKeys = streamSideKeyGenerator()
|
||||
|
||||
if (hashedRelation.keyIsUnique) {
|
||||
if (hashedRelation == EmptyHashedRelation) {
|
||||
Iterator.empty
|
||||
} else if (hashedRelation.keyIsUnique) {
|
||||
streamIter.flatMap { srow =>
|
||||
joinRow.withLeft(srow)
|
||||
val matched = hashedRelation.getValue(joinKeys(srow))
|
||||
|
@ -230,7 +232,9 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
val joinKeys = streamSideKeyGenerator()
|
||||
val joinedRow = new JoinedRow
|
||||
|
||||
if (hashedRelation.keyIsUnique) {
|
||||
if (hashedRelation == EmptyHashedRelation) {
|
||||
Iterator.empty
|
||||
} else if (hashedRelation.keyIsUnique) {
|
||||
streamIter.filter { current =>
|
||||
val key = joinKeys(current)
|
||||
lazy val matched = hashedRelation.getValue(key)
|
||||
|
@ -432,7 +436,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
* Generates the code for Inner join.
|
||||
*/
|
||||
protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
||||
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
|
||||
val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
|
||||
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
|
||||
val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
|
||||
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||
|
@ -442,7 +446,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
case BuildRight => input ++ buildVars
|
||||
}
|
||||
|
||||
if (keyIsUnique) {
|
||||
if (isEmptyHashedRelation) {
|
||||
"""
|
||||
|// If HashedRelation is empty, hash inner join simply returns nothing.
|
||||
""".stripMargin
|
||||
} else if (keyIsUnique) {
|
||||
s"""
|
||||
|// generate join key for stream side
|
||||
|${keyEv.code}
|
||||
|
@ -559,12 +567,16 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
* Generates the code for left semi join.
|
||||
*/
|
||||
protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
|
||||
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
|
||||
val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
|
||||
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
|
||||
val (matched, checkCondition, _) = getJoinCondition(ctx, input)
|
||||
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||
|
||||
if (keyIsUnique) {
|
||||
if (isEmptyHashedRelation) {
|
||||
"""
|
||||
|// If HashedRelation is empty, hash semi join simply returns nothing.
|
||||
""".stripMargin
|
||||
} else if (keyIsUnique) {
|
||||
s"""
|
||||
|// generate join key for stream side
|
||||
|${keyEv.code}
|
||||
|
@ -612,10 +624,10 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
|||
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||
if (isEmptyHashedRelation) {
|
||||
return s"""
|
||||
|// If the right side is empty, Anti Join simply returns the left side.
|
||||
|// If HashedRelation is empty, hash anti join simply returns the stream side.
|
||||
|$numOutput.add(1);
|
||||
|${consume(ctx, input)}
|
||||
|""".stripMargin
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
|
|||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Filter
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
|
||||
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
||||
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
|
||||
import org.apache.spark.sql.execution.joins._
|
||||
|
@ -1254,4 +1254,56 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-32649: Optimize BHJ/SHJ inner/semi join with empty hashed relation") {
|
||||
val inputDFs = Seq(
|
||||
// Test empty build side for inner join
|
||||
(spark.range(30).selectExpr("id as k1"),
|
||||
spark.range(10).selectExpr("id as k2").filter("k2 < -1"),
|
||||
"inner"),
|
||||
// Test empty build side for semi join
|
||||
(spark.range(30).selectExpr("id as k1"),
|
||||
spark.range(10).selectExpr("id as k2").filter("k2 < -1"),
|
||||
"semi")
|
||||
)
|
||||
inputDFs.foreach { case (df1, df2, joinType) =>
|
||||
// Test broadcast hash join
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") {
|
||||
val bhjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
|
||||
assert(bhjCodegenDF.queryExecution.executedPlan.collect {
|
||||
case WholeStageCodegenExec(_ : BroadcastHashJoinExec) => true
|
||||
case WholeStageCodegenExec(ProjectExec(_, _ : BroadcastHashJoinExec)) => true
|
||||
}.size === 1)
|
||||
checkAnswer(bhjCodegenDF, Seq.empty)
|
||||
|
||||
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
|
||||
val bhjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
|
||||
assert(bhjNonCodegenDF.queryExecution.executedPlan.collect {
|
||||
case _: BroadcastHashJoinExec => true }.size === 1)
|
||||
checkAnswer(bhjNonCodegenDF, Seq.empty)
|
||||
}
|
||||
}
|
||||
|
||||
// Test shuffled hash join
|
||||
withSQLConf(SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
|
||||
// Set broadcast join threshold and number of shuffle partitions,
|
||||
// as shuffled hash join depends on these two configs.
|
||||
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "50",
|
||||
SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
|
||||
val shjCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
|
||||
assert(shjCodegenDF.queryExecution.executedPlan.collect {
|
||||
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
|
||||
case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true
|
||||
}.size === 1)
|
||||
checkAnswer(shjCodegenDF, Seq.empty)
|
||||
|
||||
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
|
||||
val shjNonCodegenDF = df1.join(df2, $"k1" === $"k2", joinType)
|
||||
assert(shjNonCodegenDF.queryExecution.executedPlan.collect {
|
||||
case _: ShuffledHashJoinExec => true }.size === 1)
|
||||
checkAnswer(shjNonCodegenDF, Seq.empty)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -226,7 +226,8 @@ class AdaptiveQueryExecSuite
|
|||
val df1 = spark.range(10).withColumn("a", 'id)
|
||||
val df2 = spark.range(10).withColumn("b", 'id)
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
|
||||
val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
|
||||
val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
|
||||
.groupBy('a).count()
|
||||
checkAnswer(testDf, Seq())
|
||||
val plan = testDf.queryExecution.executedPlan
|
||||
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
|
||||
|
@ -238,7 +239,8 @@ class AdaptiveQueryExecSuite
|
|||
}
|
||||
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
|
||||
val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
|
||||
val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
|
||||
.groupBy('a).count()
|
||||
checkAnswer(testDf, Seq())
|
||||
val plan = testDf.queryExecution.executedPlan
|
||||
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
|
||||
|
@ -1181,4 +1183,26 @@ class AdaptiveQueryExecSuite
|
|||
checkNumLocalShuffleReaders(adaptivePlan)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-32649: Eliminate inner and semi join to empty relation") {
|
||||
withSQLConf(
|
||||
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
|
||||
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
|
||||
Seq(
|
||||
// inner join (small table at right side)
|
||||
"SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1",
|
||||
// inner join (small table at left side)
|
||||
"SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1",
|
||||
// left semi join
|
||||
"SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1"
|
||||
).foreach(query => {
|
||||
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
|
||||
val smj = findTopLevelSortMergeJoin(plan)
|
||||
assert(smj.size == 1)
|
||||
val join = findTopLevelBaseJoin(adaptivePlan)
|
||||
assert(join.isEmpty)
|
||||
checkNumLocalShuffleReaders(adaptivePlan)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue