diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6b3f301aad..0ed7c2f555 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -158,45 +158,65 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def smallerSide = if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft - val buildRight = canBuildRight && right.stats.hints.broadcast - val buildLeft = canBuildLeft && left.stats.hints.broadcast - - if (buildRight && buildLeft) { + if (canBuildRight && canBuildLeft) { // Broadcast smaller side base on its estimated physical size // if both sides have broadcast hint smallerSide - } else if (buildRight) { + } else if (canBuildRight) { BuildRight - } else if (buildLeft) { + } else if (canBuildLeft) { BuildLeft - } else if (canBuildRight && canBuildLeft) { + } else { // for the last default broadcast nested loop join smallerSide - } else { - throw new AnalysisException("Can not decide which side to broadcast for this join") } } + private def canBroadcastByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + buildLeft || buildRight + } + + private def broadcastSideByHints(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && left.stats.hints.broadcast + val buildRight = canBuildRight(joinType) && right.stats.hints.broadcast + broadcastSide(buildLeft, buildRight, left, right) + } + + private def canBroadcastBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : Boolean = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + buildLeft || buildRight + } + + private def broadcastSideBySizes(joinType: JoinType, left: LogicalPlan, right: LogicalPlan) + : BuildSide = { + val buildLeft = canBuildLeft(joinType) && canBroadcast(left) + val buildRight = canBuildRight(joinType) && canBroadcast(right) + broadcastSide(buildLeft, buildRight, left, right) + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // --- BroadcastHashJoin -------------------------------------------------------------------- + // broadcast hints were specified case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) Seq(joins.BroadcastHashJoinExec( leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + // broadcast hints were not specified, so need to infer it from size and configuration. case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildRight(joinType) && canBroadcast(right) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = broadcastSideBySizes(joinType, left, right) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) - - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) - if canBuildLeft(joinType) && canBroadcast(left) => - Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- @@ -225,27 +245,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Pick BroadcastNestedLoopJoin if one side could be broadcasted case j @ logical.Join(left, right, joinType, condition) - if (canBuildRight(joinType) && right.stats.hints.broadcast) || - (canBuildLeft(joinType) && left.stats.hints.broadcast) => - val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right) + if canBroadcastByHints(joinType, left, right) => + val buildSide = broadcastSideByHints(joinType, left, right) joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil case j @ logical.Join(left, right, joinType, condition) - if canBuildRight(joinType) && canBroadcast(right) => + if canBroadcastBySizes(joinType, left, right) => + val buildSide = broadcastSideBySizes(joinType, left, right) joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil - case j @ logical.Join(left, right, joinType, condition) - if canBuildLeft(joinType) && canBroadcast(left) => - joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition) :: Nil // Pick CartesianProduct for InnerJoin case logical.Join(left, right, _: InnerLike, condition) => joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil case logical.Join(left, right, joinType, condition) => - val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right) + val buildSide = broadcastSide( + left.stats.hints.broadcast, right.stats.hints.broadcast, left, right) // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 67e2cdc739..6da46ea348 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -225,17 +225,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("Shouldn't change broadcast join buildSide if user clearly specified") { - def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { - val executedPlan = sql(sqlStr).queryExecution.executedPlan - executedPlan match { - case b: BroadcastNestedLoopJoinExec => - assert(b.getClass.getSimpleName === joinMethod) - assert(b.buildSide === buildSide) - case w: WholeStageCodegenExec => - assert(w.children.head.getClass.getSimpleName === joinMethod) - assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) - } - } withTempView("t1", "t2") { spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") @@ -246,9 +235,6 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes assert(t1Size < t2Size) - val bh = BroadcastHashJoinExec.toString - val bl = BroadcastNestedLoopJoinExec.toString - // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide( "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) @@ -266,8 +252,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight) - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { // INNER JOIN && t1Size < t2Size => BuildLeft assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft) // FULL JOIN && t1Size < t2Size => BuildLeft @@ -290,4 +275,62 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } } + + test("Shouldn't bias towards build right if user didn't specify") { + + withTempView("t1", "t2") { + spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1") + spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value") + .createTempView("t2") + + val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes + val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes + assert(t1Size < t2Size) + + assertJoinBuildSide("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1 ON t1.key = t2.key", bh, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1 ON t1.key = t2.key", bh, BuildLeft) + + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 FULL OUTER JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 LEFT JOIN t2", bl, BuildRight) + assertJoinBuildSide("SELECT * FROM t2 LEFT JOIN t1", bl, BuildRight) + + assertJoinBuildSide("SELECT * FROM t1 RIGHT JOIN t2", bl, BuildLeft) + assertJoinBuildSide("SELECT * FROM t2 RIGHT JOIN t1", bl, BuildLeft) + } + } + } + + private val bh = BroadcastHashJoinExec.toString + private val bl = BroadcastNestedLoopJoinExec.toString + + private def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = { + val executedPlan = sql(sqlStr).queryExecution.executedPlan + executedPlan match { + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case b: BroadcastNestedLoopJoinExec => + assert(b.getClass.getSimpleName === joinMethod) + assert(b.buildSide === buildSide) + case w: WholeStageCodegenExec => + assert(w.children.head.getClass.getSimpleName === joinMethod) + if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) { + assert( + w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide) + } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) { + assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide) + } else { + fail() + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index fc3483379c..a3a3f3851e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -478,15 +478,22 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared spark.range(10).write.parquet(dir) spark.read.parquet(dir).createOrReplaceTempView("pqS") + // The executed plan looks like: + // Exchange RoundRobinPartitioning(2) + // +- BroadcastNestedLoopJoin BuildLeft, Cross + // :- BroadcastExchange IdentityBroadcastMode + // : +- Exchange RoundRobinPartitioning(3) + // : +- *Range (0, 30, step=1, splits=2) + // +- *FileScan parquet [id#465L] Batched: true, Format: Parquet, Location: ...(ignored) val res3 = InputOutputMetricsHelper.run( spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF() ) // The query above is executed in the following stages: - // 1. sql("select * from pqS") => (10, 0, 10) - // 2. range(30) => (30, 0, 30) - // 3. crossJoin(...) of 1. and 2. => (0, 30, 300) + // 1. range(30) => (30, 0, 30) + // 2. sql("select * from pqS") => (0, 30, 0) + // 3. crossJoin(...) of 1. and 2. => (10, 0, 300) // 4. shuffle & return results => (0, 300, 0) - assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil) + assert(res3 === (30L, 0L, 30L) :: (0L, 30L, 0L) :: (10L, 0L, 300L) :: (0L, 300L, 0L) :: Nil) } }