[SPARK-22916][SQL] shouldn't bias towards build right if user does not specify
## What changes were proposed in this pull request? When there are no broadcast hints, the current Spark strategies prefer to build the right side without considering the sizes of the two tables. This patch adds logic that takes the sizes of both tables into account when choosing the build side. To make the logic clear, the build side is determined as follows: 1. If there are broadcast hints, the build side is determined by `broadcastSideByHints`; 2. If there are no broadcast hints, the build side is determined by `broadcastSideBySizes`; 3. If broadcast is disabled by the config, planning falls through to the subsequent cases. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu <fengliu@databricks.com> Closes #20099 from liufengdb/fix-spark-strategies.
This commit is contained in:
parent
224375c55f
commit
cc30ef8009
|
@ -158,45 +158,65 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
def smallerSide =
|
||||
if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
|
||||
|
||||
val buildRight = canBuildRight && right.stats.hints.broadcast
|
||||
val buildLeft = canBuildLeft && left.stats.hints.broadcast
|
||||
|
||||
if (buildRight && buildLeft) {
|
||||
if (canBuildRight && canBuildLeft) {
|
||||
// Broadcast the smaller side based on its estimated physical size
|
||||
// if both sides have broadcast hint
|
||||
smallerSide
|
||||
} else if (buildRight) {
|
||||
} else if (canBuildRight) {
|
||||
BuildRight
|
||||
} else if (buildLeft) {
|
||||
} else if (canBuildLeft) {
|
||||
BuildLeft
|
||||
} else if (canBuildRight && canBuildLeft) {
|
||||
} else {
|
||||
// for the last default broadcast nested loop join
|
||||
smallerSide
|
||||
} else {
|
||||
throw new AnalysisException("Can not decide which side to broadcast for this join")
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Tells whether a broadcast join can be planned purely from user-supplied broadcast hints.
 *
 * A side qualifies when the join type allows building on that side (`canBuildLeft` /
 * `canBuildRight`) and that side's statistics carry a broadcast hint.
 *
 * @param joinType the type of the join being planned
 * @param left the left child of the join
 * @param right the right child of the join
 * @return true if at least one side qualifies
 */
private def canBroadcastByHints(
    joinType: JoinType,
    left: LogicalPlan,
    right: LogicalPlan): Boolean = {
  // Both sides are evaluated up front, exactly as the size-based variant does.
  val leftHinted = canBuildLeft(joinType) && left.stats.hints.broadcast
  val rightHinted = canBuildRight(joinType) && right.stats.hints.broadcast
  leftHinted || rightHinted
}
|
||||
|
||||
/**
 * Picks the build side for a join when broadcast hints are present.
 *
 * Each side is a candidate when the join type allows building on it and its statistics
 * carry a broadcast hint; the final decision between the candidates is delegated to
 * `broadcastSide`.
 *
 * @param joinType the type of the join being planned
 * @param left the left child of the join
 * @param right the right child of the join
 * @return the side chosen by `broadcastSide` for the hinted candidates
 */
private def broadcastSideByHints(
    joinType: JoinType,
    left: LogicalPlan,
    right: LogicalPlan): BuildSide = {
  val leftHinted = canBuildLeft(joinType) && left.stats.hints.broadcast
  val rightHinted = canBuildRight(joinType) && right.stats.hints.broadcast
  broadcastSide(leftHinted, rightHinted, left, right)
}
|
||||
|
||||
/**
 * Tells whether a broadcast join can be planned without hints, based on `canBroadcast`.
 *
 * A side qualifies when the join type allows building on that side (`canBuildLeft` /
 * `canBuildRight`) and `canBroadcast` accepts that side's plan.
 *
 * @param joinType the type of the join being planned
 * @param left the left child of the join
 * @param right the right child of the join
 * @return true if at least one side qualifies
 */
private def canBroadcastBySizes(
    joinType: JoinType,
    left: LogicalPlan,
    right: LogicalPlan): Boolean = {
  // Both sides are evaluated up front, exactly as the hint-based variant does.
  val leftFits = canBuildLeft(joinType) && canBroadcast(left)
  val rightFits = canBuildRight(joinType) && canBroadcast(right)
  leftFits || rightFits
}
|
||||
|
||||
/**
 * Picks the build side for a join when no broadcast hints apply.
 *
 * Each side is a candidate when the join type allows building on it and `canBroadcast`
 * accepts its plan; the final decision between the candidates is delegated to
 * `broadcastSide`.
 *
 * @param joinType the type of the join being planned
 * @param left the left child of the join
 * @param right the right child of the join
 * @return the side chosen by `broadcastSide` for the broadcastable candidates
 */
private def broadcastSideBySizes(
    joinType: JoinType,
    left: LogicalPlan,
    right: LogicalPlan): BuildSide = {
  val leftFits = canBuildLeft(joinType) && canBroadcast(left)
  val rightFits = canBuildRight(joinType) && canBroadcast(right)
  broadcastSide(leftFits, rightFits, left, right)
}
|
||||
|
||||
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
|
||||
// --- BroadcastHashJoin --------------------------------------------------------------------
|
||||
|
||||
// broadcast hints were specified
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
|
||||
if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
|
||||
(canBuildLeft(joinType) && left.stats.hints.broadcast) =>
|
||||
val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
|
||||
if canBroadcastByHints(joinType, left, right) =>
|
||||
val buildSide = broadcastSideByHints(joinType, left, right)
|
||||
Seq(joins.BroadcastHashJoinExec(
|
||||
leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
|
||||
|
||||
// broadcast hints were not specified, so need to infer it from size and configuration.
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
|
||||
if canBuildRight(joinType) && canBroadcast(right) =>
|
||||
if canBroadcastBySizes(joinType, left, right) =>
|
||||
val buildSide = broadcastSideBySizes(joinType, left, right)
|
||||
Seq(joins.BroadcastHashJoinExec(
|
||||
leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
|
||||
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
|
||||
if canBuildLeft(joinType) && canBroadcast(left) =>
|
||||
Seq(joins.BroadcastHashJoinExec(
|
||||
leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
|
||||
leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
|
||||
|
||||
// --- ShuffledHashJoin ---------------------------------------------------------------------
|
||||
|
||||
|
@ -225,27 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
|
||||
// Pick BroadcastNestedLoopJoin if one side could be broadcasted
|
||||
case j @ logical.Join(left, right, joinType, condition)
|
||||
if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
|
||||
(canBuildLeft(joinType) && left.stats.hints.broadcast) =>
|
||||
val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
|
||||
if canBroadcastByHints(joinType, left, right) =>
|
||||
val buildSide = broadcastSideByHints(joinType, left, right)
|
||||
joins.BroadcastNestedLoopJoinExec(
|
||||
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
|
||||
|
||||
case j @ logical.Join(left, right, joinType, condition)
|
||||
if canBuildRight(joinType) && canBroadcast(right) =>
|
||||
if canBroadcastBySizes(joinType, left, right) =>
|
||||
val buildSide = broadcastSideBySizes(joinType, left, right)
|
||||
joins.BroadcastNestedLoopJoinExec(
|
||||
planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
|
||||
case j @ logical.Join(left, right, joinType, condition)
|
||||
if canBuildLeft(joinType) && canBroadcast(left) =>
|
||||
joins.BroadcastNestedLoopJoinExec(
|
||||
planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
|
||||
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
|
||||
|
||||
// Pick CartesianProduct for InnerJoin
|
||||
case logical.Join(left, right, _: InnerLike, condition) =>
|
||||
joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
|
||||
|
||||
case logical.Join(left, right, joinType, condition) =>
|
||||
val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right)
|
||||
val buildSide = broadcastSide(
|
||||
left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
|
||||
// This join could be very slow or OOM
|
||||
joins.BroadcastNestedLoopJoinExec(
|
||||
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
|
||||
|
|
|
@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("Shouldn't change broadcast join buildSide if user clearly specified") {
|
||||
def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
|
||||
val executedPlan = sql(sqlStr).queryExecution.executedPlan
|
||||
executedPlan match {
|
||||
case b: BroadcastNestedLoopJoinExec =>
|
||||
assert(b.getClass.getSimpleName === joinMethod)
|
||||
assert(b.buildSide === buildSide)
|
||||
case w: WholeStageCodegenExec =>
|
||||
assert(w.children.head.getClass.getSimpleName === joinMethod)
|
||||
assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
|
||||
}
|
||||
}
|
||||
|
||||
withTempView("t1", "t2") {
|
||||
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
|
||||
|
@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
|
||||
assert(t1Size < t2Size)
|
||||
|
||||
val bh = BroadcastHashJoinExec.toString
|
||||
val bl = BroadcastNestedLoopJoinExec.toString
|
||||
|
||||
// INNER JOIN && t1Size < t2Size => BuildLeft
|
||||
assertJoinBuildSide(
|
||||
"SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
|
||||
|
@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
"SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight)
|
||||
|
||||
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
|
||||
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
// INNER JOIN && t1Size < t2Size => BuildLeft
|
||||
assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft)
|
||||
// FULL JOIN && t1Size < t2Size => BuildLeft
|
||||
|
@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("Shouldn't bias towards build right if user didn't specify") {
  // Runs every (query, expected build side) pair, in order, against the given join operator name.
  def checkAll(joinMethod: String, cases: Seq[(String, BuildSide)]): Unit =
    cases.foreach { case (query, side) => assertJoinBuildSide(query, joinMethod, side) }

  withTempView("t1", "t2") {
    spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
    spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
      .createTempView("t2")

    // t1 holds fewer rows than t2; pin the size ordering that the expectations below rely on.
    val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
    val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
    assert(t1Size < t2Size)

    // Equi-joins without any broadcast hint.
    checkAll(bh, Seq(
      "SELECT * FROM t1 JOIN t2 ON t1.key = t2.key" -> BuildLeft,
      "SELECT * FROM t2 JOIN t1 ON t1.key = t2.key" -> BuildRight,
      "SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key" -> BuildRight,
      "SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key" -> BuildRight,
      "SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" -> BuildLeft,
      "SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key" -> BuildLeft))

    // Joins without a join condition; these need cross joins to be enabled.
    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
      checkAll(bl, Seq(
        "SELECT * FROM t1 FULL OUTER JOIN t2" -> BuildLeft,
        "SELECT * FROM t2 FULL OUTER JOIN t1" -> BuildRight,
        "SELECT * FROM t1 LEFT JOIN t2" -> BuildRight,
        "SELECT * FROM t2 LEFT JOIN t1" -> BuildRight,
        "SELECT * FROM t1 RIGHT JOIN t2" -> BuildLeft,
        "SELECT * FROM t2 RIGHT JOIN t1" -> BuildLeft))
    }
  }
}
|
||||
|
||||
// Operator names handed to assertJoinBuildSide, where they are compared against
// `getClass.getSimpleName` of the planned join node.
private val bh = BroadcastHashJoinExec.toString
private val bl = BroadcastNestedLoopJoinExec.toString

/**
 * Plans `sqlStr` and asserts that the resulting join uses the expected operator and
 * build side.
 *
 * The join is looked for either at the root of the executed plan (a
 * `BroadcastNestedLoopJoinExec`) or as the first child of a root `WholeStageCodegenExec`
 * (a `BroadcastNestedLoopJoinExec` or `BroadcastHashJoinExec`). Any other plan shape
 * fails the test with a message describing what was found.
 *
 * @param sqlStr the SQL query to plan
 * @param joinMethod the expected simple class name of the join operator
 * @param buildSide the expected build side of the join
 */
private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
  val executedPlan = sql(sqlStr).queryExecution.executedPlan
  executedPlan match {
    case b: BroadcastNestedLoopJoinExec =>
      assert(b.getClass.getSimpleName === joinMethod)
      assert(b.buildSide === buildSide)
    case w: WholeStageCodegenExec =>
      val join = w.children.head
      assert(join.getClass.getSimpleName === joinMethod)
      // Match on the concrete join type instead of isInstanceOf/asInstanceOf chains.
      join match {
        case nested: BroadcastNestedLoopJoinExec => assert(nested.buildSide === buildSide)
        case hashed: BroadcastHashJoinExec => assert(hashed.buildSide === buildSide)
        case other =>
          fail(s"Expected a broadcast join under whole-stage codegen but found: $other")
      }
    case other =>
      // Report an unexpected plan root as a test failure rather than a bare MatchError.
      fail(s"Expected a broadcast join at the root of the executed plan but found: $other")
  }
}
|
||||
|
|
|
@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
|
|||
spark.range(10).write.parquet(dir)
|
||||
spark.read.parquet(dir).createOrReplaceTempView("pqS")
|
||||
|
||||
// The executed plan looks like:
|
||||
// Exchange RoundRobinPartitioning(2)
|
||||
// +- BroadcastNestedLoopJoin BuildLeft, Cross
|
||||
// :- BroadcastExchange IdentityBroadcastMode
|
||||
// : +- Exchange RoundRobinPartitioning(3)
|
||||
// : +- *Range (0, 30, step=1, splits=2)
|
||||
// +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored)
|
||||
val res3 = InputOutputMetricsHelper.run(
|
||||
spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
|
||||
)
|
||||
// The query above is executed in the following stages:
|
||||
// 1. sql("select * from pqS") => (10, 0, 10)
|
||||
// 2. range(30) => (30, 0, 30)
|
||||
// 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
|
||||
// 1. range(30) => (30, 0, 30)
|
||||
// 2. sql("select * from pqS") => (0, 30, 0)
|
||||
// 3. crossJoin(...) of 1. and 2. => (10, 0, 300)
|
||||
// 4. shuffle & return results => (0, 300, 0)
|
||||
assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
|
||||
assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue