[SPARK-13977] [SQL] Brings back Shuffled hash join

## What changes were proposed in this pull request?

ShuffledHashJoin (also outer join) is removed in 1.6, in favor of SortMergeJoin, which is more robust and also fast.

ShuffledHashJoin is still useful in this case: 1) one table is much smaller than the other one, then cost to build a hash table on smaller table is smaller than sorting the larger table 2) any partition of the small table could fit in memory.

This PR brings back ShuffledHashJoin, basically revert #9645, and fix the conflict. Also merging outer join and left-semi join into the same class. This PR does not implement full outer join, because it's not implemented efficiently (requiring build hash table on both side).

A simple benchmark (one table is 5x smaller than other one) show that ShuffledHashJoin could be 2X faster than SortMergeJoin.

## How was this patch tested?

Added new unit tests for ShuffledHashJoin.

Author: Davies Liu <davies@databricks.com>

Closes #11788 from davies/shuffle_join.
This commit is contained in:
Davies Liu 2016-03-18 10:32:53 -07:00 committed by Davies Liu
parent 14c7236dc6
commit 9c23c818ca
13 changed files with 277 additions and 118 deletions

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@ -29,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
import org.apache.spark.sql.internal.SQLConf
@ -69,8 +69,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
joins.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
Seq(joins.ShuffledHashJoin(
leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
@ -80,8 +80,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object CanBroadcast {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
if (conf.autoBroadcastJoinThreshold > 0 &&
plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
Some(plan)
} else {
None
@ -101,10 +100,41 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
* of the join will be broadcasted and the other side will be streamed, with no shuffling
* performed. If both sides of the join are eligible to be broadcasted then the
* - Shuffle hash join: if single partition is small enough to build a hash table.
* - Sort merge: if the matching join keys are sortable.
*/
object EquiJoinSelection extends Strategy with PredicateHelper {
/**
* Matches a plan whose single partition should be small enough to build a hash table.
*/
def canBuildHashMap(plan: LogicalPlan): Boolean = {
plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
}
/**
* Returns whether plan a is much smaller (3X) than plan b.
*
* The cost to build hash map is higher than sorting, we should only build hash map on a table
* that is much smaller than other one. Since we does not have the statistic for number of rows,
* use the size of bytes here as estimation.
*/
private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
}
/**
* Returns whether we should use shuffle hash join or not.
*
* We should only use shuffle hash join when:
* 1) any single partition of a small table could fit in memory.
* 2) the smaller table is much smaller (3X) than the other one.
*/
private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
canBuildHashMap(left) && muchSmaller(left, right) ||
canBuildHashMap(right) && muchSmaller(right, left)
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// --- Inner joins --------------------------------------------------------------------------
@ -117,6 +147,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
!RowOrdering.isOrderable(leftKeys) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
BuildRight
} else {
BuildLeft
}
Seq(joins.ShuffledHashJoin(
leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(
@ -134,6 +176,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(joins.BroadcastHashJoin(
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoin(
leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoin(
leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoin(

View file

@ -1,58 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Build the right table's join keys into a HashedRelation, and iteratively go through the left
* table, to find if the join keys are in the HashedRelation.
*/
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan,
condition: Option[Expression]) extends BinaryNode with HashJoin {
override val joinType = LeftSemi
override val buildSide = BuildRight
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
override def outputPartitioning: Partitioning = left.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
hashSemiJoin(streamIter, hashRelation, numOutputRows)
}
}
}

View file

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Performs an inner hash join of two child relations by first shuffling the data using the join
* keys.
*/
case class ShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
override def outputPartitioning: Partitioning = joinType match {
case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
case LeftSemi => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
}
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
val joinedRow = new JoinedRow
joinType match {
case Inner =>
hashJoin(streamIter, hashed, numOutputRows)
case LeftSemi =>
hashSemiJoin(streamIter, hashed, numOutputRows)
case LeftOuter =>
val keyGenerator = streamSideKeyGenerator
val resultProj = createResultProjection
streamIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
})
case RightOuter =>
val keyGenerator = streamSideKeyGenerator
val resultProj = createResultProjection
streamIter.flatMap(currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)
rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
})
case x =>
throw new IllegalArgumentException(
s"ShuffledHashJoin should not take $x as the JoinType")
}
}
}
}

View file

@ -665,11 +665,11 @@ private[joins] class SortMergeJoinScanner(
* An iterator for outputting rows in left outer join.
*/
private class LeftOuterIterator(
smjScanner: SortMergeJoinScanner,
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric)
smjScanner: SortMergeJoinScanner,
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
@ -681,13 +681,12 @@ private class LeftOuterIterator(
* An iterator for outputting rows in right outer join.
*/
private class RightOuterIterator(
smjScanner: SortMergeJoinScanner,
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric)
extends OneSideOuterIterator(
smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
smjScanner: SortMergeJoinScanner,
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric)
extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
@ -710,11 +709,11 @@ private class RightOuterIterator(
* @param numOutputRows an accumulator metric for the number of rows output
*/
private abstract class OneSideOuterIterator(
smjScanner: SortMergeJoinScanner,
bufferedSideNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric) extends RowIterator {
smjScanner: SortMergeJoinScanner,
bufferedSideNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
numOutputRows: LongSQLMetric) extends RowIterator {
// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@ -777,14 +776,14 @@ private abstract class OneSideOuterIterator(
}
private class SortMergeFullOuterJoinScanner(
leftKeyGenerator: Projection,
rightKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
leftIter: RowIterator,
rightIter: RowIterator,
boundCondition: InternalRow => Boolean,
leftNullRow: InternalRow,
rightNullRow: InternalRow) {
leftKeyGenerator: Projection,
rightKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
leftIter: RowIterator,
rightIter: RowIterator,
boundCondition: InternalRow => Boolean,
leftNullRow: InternalRow,
rightNullRow: InternalRow) {
private[this] val joinedRow: JoinedRow = new JoinedRow()
private[this] var leftRow: InternalRow = _
private[this] var leftRowKey: InternalRow = _
@ -950,10 +949,9 @@ private class SortMergeFullOuterJoinScanner(
}
private class FullOuterIterator(
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
numRows: LongSQLMetric
) extends RowIterator {
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
numRows: LongSQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
override def advanceNext(): Boolean = {

View file

@ -236,6 +236,11 @@ object SQLConf {
doc = "When true, enable partition pruning for in-memory columnar tables.",
isPublic = false)
val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin",
defaultValue = Some(true),
doc = "When true, prefer sort merge join over shuffle hash join",
isPublic = false)
val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
defaultValue = Some(10 * 1024 * 1024),
doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
@ -586,6 +591,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
def defaultSizeInBytes: Long =
getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)

View file

@ -45,8 +45,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: LeftSemiJoinHash => j
case j: BroadcastHashJoin => j
case j: ShuffledHashJoin => j
case j: CartesianProduct => j
case j: BroadcastNestedLoopJoin => j
case j: SortMergeJoin => j
@ -63,7 +63,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
@ -434,7 +434,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
).foreach {
case (query, joinClass) => assertJoin(query, joinClass)
}
@ -460,7 +460,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[LeftSemiJoinHash]),
classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT SEMI JOIN testData2",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData JOIN testData2",

View file

@ -247,7 +247,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
ignore("rube") {
ignore("shuffle hash join") {
val N = 4 << 20
sqlContext.setConf("spark.sql.shuffle.partitions", "2")
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000")
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false")
runBenchmark("shuffle hash join", N) {
val df1 = sqlContext.range(N).selectExpr(s"id as k1")
val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2")
df1.join(df2, col("k1") === col("k2")).count()
}
/**
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X
shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X
*/
}
ignore("cube") {
val N = 5 << 20
runBenchmark("cube", N) {

View file

@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@ -143,7 +143,7 @@ class PlannerSuite extends SharedSQLContext {
val sortMergeJoins = planned.collect { case join: SortMergeJoin => join }
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(sortMergeJoins.isEmpty, "Should not use sort merge join")
assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join")
sqlContext.clearCache()
}

View file

@ -101,6 +101,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin)
}
def makeShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val shuffledHashJoin =
joins.ShuffledHashJoin(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin)
}
def makeSortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
@ -136,6 +150,30 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}
}
test(s"$testName using ShuffledHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using ShuffledHashJoin (build=right)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using SortMergeJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {

View file

@ -76,6 +76,22 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
ExtractEquiJoinKeys.unapply(join)
}
if (joinType != FullOuter) {
test(s"$testName using ShuffledHashJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext.sessionState.conf).apply(
ShuffledHashJoin(
leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashJoin") {
val buildSide = joinType match {

View file

@ -72,12 +72,13 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
ExtractEquiJoinKeys.unapply(join)
}
test(s"$testName using LeftSemiJoinHash") {
test(s"$testName using ShuffledHashJoin") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
ShuffledHashJoin(
leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}

View file

@ -263,32 +263,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
test("LeftSemiJoinHash metrics") {
test("ShuffledHashJoin metrics") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
// Assume the execution plan is
// ... -> LeftSemiJoinHash(nodeId = 0)
// ... -> ShuffledHashJoin(nodeId = 0)
val df = df1.join(df2, $"key" === $"key2", "leftsemi")
testSparkPlanMetrics(df, 1, Map(
0L -> ("LeftSemiJoinHash", Map(
0L -> ("ShuffledHashJoin", Map(
"number of output rows" -> 2L)))
)
}
}
test("LeftSemiJoinBNL metrics") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
// Assume the execution plan is
// ... -> LeftSemiJoinBNL(nodeId = 0)
val df = df1.join(df2, $"key" < $"key2", "leftsemi")
testSparkPlanMetrics(df, 2, Map(
0L -> ("LeftSemiJoinBNL", Map(
"number of output rows" -> 2L)))
)
}
test("CartesianProduct metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")

View file

@ -230,7 +230,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
val shj = df.queryExecution.sparkPlan.collect {
case j: LeftSemiJoinHash => j
case j: ShuffledHashJoin => j
}
assert(shj.size === 1,
"LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")