[SPARK-9730] [SQL] Add Full Outer Join support for SortMergeJoin
This PR is based on #8383 (thanks to viirya). JIRA: https://issues.apache.org/jira/browse/SPARK-9730. This patch adds Full Outer Join support for SortMergeJoin. A new class, SortMergeFullOuterJoinScanner, is added to scan rows from the left and right iterators. FullOuterIterator is simply a wrapper of type RowIterator that consumes joined rows from SortMergeFullOuterJoinScanner. Closes #8383. Author: Liang-Chi Hsieh <viirya@appier.com> Author: Davies Liu <davies@databricks.com> Closes #8579 from davies/smj_fullouter.
This commit is contained in:
parent
71da1633c4
commit
45de518742
|
@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable {
|
|||
*/
|
||||
/** Total number of bits this bitset can hold (64 bits per backing word). */
def capacity: Int = numWords * 64
|
||||
|
||||
/**
|
||||
* Clear all set bits.
|
||||
*/
|
||||
/** Reset every bit to 0 by zeroing each backing word. */
def clear(): Unit = {
  var w = 0
  while (w < numWords) {
    words(w) = 0L
    w += 1
  }
}
|
||||
|
||||
/**
|
||||
* Set all the bits up to a given index
|
||||
*/
|
||||
|
|
|
@ -132,15 +132,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
joins.BroadcastHashOuterJoin(
|
||||
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
|
||||
|
||||
case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
|
||||
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
|
||||
joins.SortMergeOuterJoin(
|
||||
leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
|
||||
|
||||
case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
|
||||
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
|
||||
joins.SortMergeOuterJoin(
|
||||
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
|
||||
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
|
||||
|
||||
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
|
||||
joins.ShuffledHashOuterJoin(
|
||||
|
|
|
@ -17,20 +17,21 @@
|
|||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.annotation.DeveloperApi
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
|
||||
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
|
||||
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
|
||||
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
|
||||
import org.apache.spark.util.collection.BitSet
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
 * Performs a sort merge outer join of two child relations.
|
||||
*
|
||||
* Note: this does not support full outer join yet; see SPARK-9730 for progress on this.
|
||||
*/
|
||||
@DeveloperApi
|
||||
case class SortMergeOuterJoin(
|
||||
|
@ -52,6 +53,8 @@ case class SortMergeOuterJoin(
|
|||
left.output ++ right.output.map(_.withNullability(true))
|
||||
case RightOuter =>
|
||||
left.output.map(_.withNullability(true)) ++ right.output
|
||||
case FullOuter =>
|
||||
(left.output ++ right.output).map(_.withNullability(true))
|
||||
case x =>
|
||||
throw new IllegalArgumentException(
|
||||
s"${getClass.getSimpleName} should not take $x as the JoinType")
|
||||
|
@ -62,6 +65,7 @@ case class SortMergeOuterJoin(
|
|||
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
|
||||
case LeftOuter => left.outputPartitioning
|
||||
case RightOuter => right.outputPartitioning
|
||||
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
|
||||
case x =>
|
||||
throw new IllegalArgumentException(
|
||||
s"${getClass.getSimpleName} should not take $x as the JoinType")
|
||||
|
@ -71,6 +75,8 @@ case class SortMergeOuterJoin(
|
|||
// For left and right outer joins, the output is ordered by the streamed input's join keys.
|
||||
case LeftOuter => requiredOrders(leftKeys)
|
||||
case RightOuter => requiredOrders(rightKeys)
|
||||
// there are null rows in both streams, so there is no order
|
||||
case FullOuter => Nil
|
||||
case x => throw new IllegalArgumentException(
|
||||
s"SortMergeOuterJoin should not take $x as the JoinType")
|
||||
}
|
||||
|
@ -165,6 +171,26 @@ case class SortMergeOuterJoin(
|
|||
new RightOuterIterator(
|
||||
smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
|
||||
|
||||
case FullOuter =>
|
||||
val leftNullRow = new GenericInternalRow(left.output.length)
|
||||
val rightNullRow = new GenericInternalRow(right.output.length)
|
||||
val smjScanner = new SortMergeFullOuterJoinScanner(
|
||||
leftKeyGenerator = createLeftKeyGenerator(),
|
||||
rightKeyGenerator = createRightKeyGenerator(),
|
||||
keyOrdering,
|
||||
leftIter = RowIterator.fromScala(leftIter),
|
||||
numLeftRows,
|
||||
rightIter = RowIterator.fromScala(rightIter),
|
||||
numRightRows,
|
||||
boundCondition,
|
||||
leftNullRow,
|
||||
rightNullRow)
|
||||
|
||||
new FullOuterIterator(
|
||||
smjScanner,
|
||||
resultProj,
|
||||
numOutputRows).toScala
|
||||
|
||||
case x =>
|
||||
throw new IllegalArgumentException(
|
||||
s"SortMergeOuterJoin should not take $x as the JoinType")
|
||||
|
@ -271,3 +297,196 @@ private class RightOuterIterator(
|
|||
|
||||
override def getRow: InternalRow = resultProj(joinedRow)
|
||||
}
|
||||
|
||||
/**
 * Drives a full outer sort-merge join over two inputs that are sorted by their join keys.
 *
 * Rows sharing a join key are buffered on both sides; every cross pair satisfying
 * `boundCondition` is emitted, and buffered rows that never matched are joined with the
 * corresponding null row, so that both inputs are fully preserved in the output.
 *
 * @param leftKeyGenerator  projection extracting the join key from a left row
 * @param rightKeyGenerator projection extracting the join key from a right row
 * @param keyOrdering       ordering used to compare join keys from either side
 * @param leftIter          sorted iterator over the left input
 * @param numLeftRows       metric incremented for every left row consumed
 * @param rightIter         sorted iterator over the right input
 * @param numRightRows      metric incremented for every right row consumed
 * @param boundCondition    extra join predicate evaluated on each candidate joined row
 * @param leftNullRow       all-null row standing in for a missing left-side match
 * @param rightNullRow      all-null row standing in for a missing right-side match
 */
private class SortMergeFullOuterJoinScanner(
    leftKeyGenerator: Projection,
    rightKeyGenerator: Projection,
    keyOrdering: Ordering[InternalRow],
    leftIter: RowIterator,
    numLeftRows: LongSQLMetric,
    rightIter: RowIterator,
    numRightRows: LongSQLMetric,
    boundCondition: InternalRow => Boolean,
    leftNullRow: InternalRow,
    rightNullRow: InternalRow) {
  // Mutable output row, reused for every result; exposed through getJoinedRow().
  private[this] val joinedRow: JoinedRow = new JoinedRow()
  // Current head row / key of each input; null once that input is exhausted.
  private[this] var leftRow: InternalRow = _
  private[this] var leftRowKey: InternalRow = _
  private[this] var rightRow: InternalRow = _
  private[this] var rightRowKey: InternalRow = _

  // Cursors into the buffers of rows that share the current matching key.
  private[this] var leftIndex: Int = 0
  private[this] var rightIndex: Int = 0
  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
  // Track which buffered rows have matched at least once; resized lazily in findMatchingRows().
  private[this] var leftMatched: BitSet = new BitSet(1)
  private[this] var rightMatched: BitSet = new BitSet(1)

  // Prime both inputs so the first advanceNext() call sees valid head rows.
  advancedLeft()
  advancedRight()

  // --- Private methods --------------------------------------------------------------------------

  /**
   * Advance the left iterator and recompute its join key.
   * @return true if a row was produced, false once the left input is exhausted.
   */
  private def advancedLeft(): Boolean = {
    if (!leftIter.advanceNext()) {
      leftRow = null
      leftRowKey = null
      false
    } else {
      leftRow = leftIter.getRow
      leftRowKey = leftKeyGenerator(leftRow)
      numLeftRows += 1
      true
    }
  }

  /**
   * Advance the right iterator and recompute its join key.
   * @return true if a row was produced, false once the right input is exhausted.
   */
  private def advancedRight(): Boolean = {
    if (!rightIter.advanceNext()) {
      rightRow = null
      rightRowKey = null
      false
    } else {
      rightRow = rightIter.getRow
      rightRowKey = rightKeyGenerator(rightRow)
      numRightRows += 1
      true
    }
  }

  /** Reuse `bits` (cleared) when it can hold `needed` entries, otherwise allocate a larger one. */
  private def resetMatched(bits: BitSet, needed: Int): BitSet = {
    if (needed <= bits.capacity) {
      bits.clear()
      bits
    } else {
      new BitSet(needed)
    }
  }

  /**
   * Buffer every row from both inputs whose key equals `matchingKey`, consuming each input
   * until its key differs, and reset the match-tracking bitsets for the new buffers.
   */
  private def findMatchingRows(matchingKey: InternalRow): Unit = {
    leftMatches.clear()
    rightMatches.clear()
    leftIndex = 0
    rightIndex = 0

    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
      // copy() so the buffered row is independent of the iterator's current row object
      leftMatches += leftRow.copy()
      advancedLeft()
    }
    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
      rightMatches += rightRow.copy()
      advancedRight()
    }

    leftMatched = resetMatched(leftMatched, leftMatches.size)
    rightMatched = resetMatched(rightMatched, rightMatches.size)
  }

  /**
   * Scan the left and right buffers for the next valid output pair.
   *
   * Note: mutates `joinedRow` to point at the rows being emitted. A left row with no valid
   * right match (or vice versa) is joined with the null row and still counts as output.
   *
   * @return true if a row pair was produced, false when the buffers are drained.
   */
  private def scanNextInBuffered(): Boolean = {
    while (leftIndex < leftMatches.size) {
      val leftCandidate = leftMatches(leftIndex)
      while (rightIndex < rightMatches.size) {
        joinedRow(leftCandidate, rightMatches(rightIndex))
        if (boundCondition(joinedRow)) {
          leftMatched.set(leftIndex)
          rightMatched.set(rightIndex)
          rightIndex += 1
          return true
        }
        rightIndex += 1
      }
      rightIndex = 0
      if (!leftMatched.get(leftIndex)) {
        // This left row matched no right row: emit it against the null row.
        joinedRow(leftCandidate, rightNullRow)
        leftIndex += 1
        return true
      }
      leftIndex += 1
    }

    while (rightIndex < rightMatches.size) {
      if (!rightMatched.get(rightIndex)) {
        // This right row matched no left row: emit it against the null row.
        joinedRow(leftNullRow, rightMatches(rightIndex))
        rightIndex += 1
        return true
      }
      rightIndex += 1
    }

    // Both buffers fully drained.
    false
  }

  // --- Public methods --------------------------------------------------------------------------

  /** The row object mutated in place by advanceNext(); callers must copy it to retain values. */
  def getJoinedRow(): JoinedRow = joinedRow

  /**
   * Produce the next joined row into getJoinedRow().
   * @return true if a row was produced, false once both inputs are exhausted.
   */
  def advanceNext(): Boolean = {
    // First drain any pairs still pending in the buffered matches.
    if ((leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) &&
        scanNextInBuffered()) {
      return true
    }

    if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
      // A null join key can never match anything: stream the left row out against nulls.
      joinedRow(leftRow.copy(), rightNullRow)
      advancedLeft()
      true
    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
      joinedRow(leftNullRow, rightRow.copy())
      advancedRight()
      true
    } else if (leftRow != null && rightRow != null) {
      // Both heads present with non-null keys: buffer all rows for the smaller key,
      // then emit the first result from the freshly filled buffers.
      val cmp = keyOrdering.compare(leftRowKey, rightRowKey)
      findMatchingRows(if (cmp <= 0) leftRowKey.copy() else rightRowKey.copy())
      scanNextInBuffered()
      true
    } else {
      // Both inputs exhausted.
      false
    }
  }
}
|
||||
|
||||
/**
 * Adapts a [[SortMergeFullOuterJoinScanner]] to the RowIterator interface, projecting each
 * joined row with `resultProj` and counting produced rows into `numRows`.
 */
private class FullOuterIterator(
    smjScanner: SortMergeFullOuterJoinScanner,
    resultProj: InternalRow => InternalRow,
    numRows: LongSQLMetric
  ) extends RowIterator {
  // The scanner mutates this single row object in place on every advance.
  private[this] val currentRow: JoinedRow = smjScanner.getJoinedRow()

  override def advanceNext(): Boolean = {
    val produced = smjScanner.advanceNext()
    if (produced) {
      numRows += 1
    }
    produced
  }

  override def getRow: InternalRow = resultProj(currentRow)
}
|
||||
|
|
|
@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
|
|||
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
|
||||
classOf[SortMergeOuterJoin]),
|
||||
("SELECT * FROM testData full outer join testData2 ON key = a",
|
||||
classOf[ShuffledHashOuterJoin]),
|
||||
classOf[SortMergeOuterJoin]),
|
||||
("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
|
||||
classOf[BroadcastNestedLoopJoin]),
|
||||
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
|
||||
|
|
|
@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
|
|||
|
||||
test(s"$testName using ShuffledHashOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (joinType != FullOuter) {
|
||||
test(s"$testName using BroadcastHashOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using SortMergeOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = false)
|
||||
}
|
||||
test(s"$testName using SortMergeOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue