[SPARK-22916][SQL] shouldn't bias towards build right if user does not specify

## What changes were proposed in this pull request?

When there are no broadcast hints, the current spark strategies will prefer to building the right side, without considering the sizes of the two tables. This patch added the logic to consider the sizes of the two tables for the build side. To make the logic clear, the build side is determined by two steps:

1. If there are broadcast hints, the build side is determined by `broadcastSideByHints`;
2. If there are no broadcast hints, the build side is determined by `broadcastSideBySizes`;
3. If the broadcast is disabled by the config, it falls back to the next cases.

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Feng Liu <fengliu@databricks.com>

Closes #20099 from liufengdb/fix-spark-strategies.
This commit is contained in:
Feng Liu 2017-12-29 18:48:47 +08:00 committed by gatorsmile
parent 224375c55f
commit cc30ef8009
3 changed files with 116 additions and 49 deletions

View file

@ -158,45 +158,65 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def smallerSide =
if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
val buildRight = canBuildRight && right.stats.hints.broadcast
val buildLeft = canBuildLeft && left.stats.hints.broadcast
if (buildRight && buildLeft) {
if (canBuildRight && canBuildLeft) {
// Broadcast smaller side base on its estimated physical size
// if both sides have broadcast hint
smallerSide
} else if (buildRight) {
} else if (canBuildRight) {
BuildRight
} else if (buildLeft) {
} else if (canBuildLeft) {
BuildLeft
} else if (canBuildRight && canBuildLeft) {
} else {
// for the last default broadcast nested loop join
smallerSide
} else {
throw new AnalysisException("Can not decide which side to broadcast for this join")
}
}
private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
: Boolean = {
val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
buildLeft || buildRight
}
private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
: BuildSide = {
val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast
val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast
broadcastSide(buildLeft, buildRight, left, right)
}
private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
: Boolean = {
val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
val buildRight = canBuildRight(joinType) && canBroadcast(right)
buildLeft || buildRight
}
private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan)
: BuildSide = {
val buildLeft = canBuildLeft(joinType) && canBroadcast(left)
val buildRight = canBuildRight(joinType) && canBroadcast(right)
broadcastSide(buildLeft, buildRight, left, right)
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- BroadcastHashJoin --------------------------------------------------------------------
// broadcast hints were specified
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
(canBuildLeft(joinType) && left.stats.hints.broadcast) =>
val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
if canBroadcastByHints(joinType, left, right) =>
val buildSide = broadcastSideByHints(joinType, left, right)
Seq(joins.BroadcastHashJoinExec(
leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
// broadcast hints were not specified, so need to infer it from size and configuration.
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if canBuildRight(joinType) && canBroadcast(right) =>
if canBroadcastBySizes(joinType, left, right) =>
val buildSide = broadcastSideBySizes(joinType, left, right)
Seq(joins.BroadcastHashJoinExec(
leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if canBuildLeft(joinType) && canBroadcast(left) =>
Seq(joins.BroadcastHashJoinExec(
leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
// --- ShuffledHashJoin ---------------------------------------------------------------------
@ -225,27 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Pick BroadcastNestedLoopJoin if one side could be broadcasted
case j @ logical.Join(left, right, joinType, condition)
if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
(canBuildLeft(joinType) && left.stats.hints.broadcast) =>
val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
if canBroadcastByHints(joinType, left, right) =>
val buildSide = broadcastSideByHints(joinType, left, right)
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case j @ logical.Join(left, right, joinType, condition)
if canBuildRight(joinType) && canBroadcast(right) =>
if canBroadcastBySizes(joinType, left, right) =>
val buildSide = broadcastSideBySizes(joinType, left, right)
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
case j @ logical.Join(left, right, joinType, condition)
if canBuildLeft(joinType) && canBroadcast(left) =>
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
// Pick CartesianProduct for InnerJoin
case logical.Join(left, right, _: InnerLike, condition) =>
joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
case logical.Join(left, right, joinType, condition) =>
val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right)
val buildSide = broadcastSide(
left.stats.hints.broadcast, right.stats.hints.broadcast, left, right)
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

View file

@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
test("Shouldn't change broadcast join buildSide if user clearly specified") {
def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
val executedPlan = sql(sqlStr).queryExecution.executedPlan
executedPlan match {
case b: BroadcastNestedLoopJoinExec =>
assert(b.getClass.getSimpleName === joinMethod)
assert(b.buildSide === buildSide)
case w: WholeStageCodegenExec =>
assert(w.children.head.getClass.getSimpleName === joinMethod)
assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
}
}
withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
assert(t1Size < t2Size)
val bh = BroadcastHashJoinExec.toString
val bl = BroadcastNestedLoopJoinExec.toString
// INNER JOIN && t1Size < t2Size => BuildLeft
assertJoinBuildSide(
"SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
"SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
// INNER JOIN && t1Size < t2Size => BuildLeft
assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft)
// FULL JOIN && t1Size < t2Size => BuildLeft
@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
}
}
}
test("Shouldn't bias towards build right if user didn't specify") {
withTempView("t1", "t2") {
spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
.createTempView("t2")
val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
assert(t1Size < t2Size)
assertJoinBuildSide("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
assertJoinBuildSide("SELECT * FROM t2 JOIN t1 ON t1.key = t2.key", bh, BuildRight)
assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight)
assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key", bh, BuildRight)
assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key", bh, BuildLeft)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft)
assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight)
assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight)
assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight)
assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft)
assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft)
}
}
}
private val bh = BroadcastHashJoinExec.toString
private val bl = BroadcastNestedLoopJoinExec.toString
private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
val executedPlan = sql(sqlStr).queryExecution.executedPlan
executedPlan match {
case b: BroadcastNestedLoopJoinExec =>
assert(b.getClass.getSimpleName === joinMethod)
assert(b.buildSide === buildSide)
case b: BroadcastNestedLoopJoinExec =>
assert(b.getClass.getSimpleName === joinMethod)
assert(b.buildSide === buildSide)
case w: WholeStageCodegenExec =>
assert(w.children.head.getClass.getSimpleName === joinMethod)
if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) {
assert(
w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide)
} else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) {
assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
} else {
fail()
}
}
}
}

View file

@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
spark.range(10).write.parquet(dir)
spark.read.parquet(dir).createOrReplaceTempView("pqS")
// The executed plan looks like:
// Exchange RoundRobinPartitioning(2)
// +- BroadcastNestedLoopJoin BuildLeft, Cross
// :- BroadcastExchange IdentityBroadcastMode
// : +- Exchange RoundRobinPartitioning(3)
// : +- *Range (0, 30, step=1, splits=2)
// +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored)
val res3 = InputOutputMetricsHelper.run(
spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
)
// The query above is executed in the following stages:
// 1. sql("select * from pqS") => (10, 0, 10)
// 2. range(30) => (30, 0, 30)
// 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
// 1. range(30) => (30, 0, 30)
// 2. sql("select * from pqS") => (0, 30, 0)
// 3. crossJoin(...) of 1. and 2. => (10, 0, 300)
// 4. shuffle & return results => (0, 300, 0)
assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil)
}
}