[SPARK-9730] [SQL] Add Full Outer Join support for SortMergeJoin

This PR is based on #8383 , thanks to viirya

JIRA: https://issues.apache.org/jira/browse/SPARK-9730

This patch adds the Full Outer Join support for SortMergeJoin. A new class SortMergeFullJoinScanner is added to scan rows from left and right iterators. FullOuterIterator is simply a wrapper of type RowIterator to consume joined rows from SortMergeFullJoinScanner.

Closes #8383

Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Davies Liu <davies@databricks.com>

Closes #8579 from davies/smj_fullouter.
This commit is contained in:
Liang-Chi Hsieh 2015-09-09 16:02:27 -07:00 committed by Davies Liu
parent 71da1633c4
commit 45de518742
5 changed files with 259 additions and 34 deletions

View file

@ -32,6 +32,17 @@ class BitSet(numBits: Int) extends Serializable {
*/
def capacity: Int = numWords * 64
/**
* Clear all set bits.
*/
def clear(): Unit = {
  // Zero every backing word; walking the array back-to-front writes the
  // same values as a forward sweep, so the result is identical.
  var idx = numWords
  while (idx > 0) {
    idx -= 1
    words(idx) = 0L
  }
}
/**
* Set all the bits up to a given index
*/

View file

@ -132,15 +132,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
joins.BroadcastHashOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeOuterJoin(
leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.sortMergeJoinEnabled && RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeOuterJoin(
leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
joins.ShuffledHashOuterJoin(

View file

@ -17,20 +17,21 @@
package org.apache.spark.sql.execution.joins
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
import org.apache.spark.util.collection.BitSet
/**
* :: DeveloperApi ::
* Performs a sort merge outer join of two child relations.
*
* Note: full outer join support was added in SPARK-9730 via SortMergeFullOuterJoinScanner.
*/
@DeveloperApi
case class SortMergeOuterJoin(
@ -52,6 +53,8 @@ case class SortMergeOuterJoin(
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
(left.output ++ right.output).map(_.withNullability(true))
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@ -62,6 +65,7 @@ case class SortMergeOuterJoin(
// For left and right outer joins, the output is partitioned by the streamed input's join keys.
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
@ -71,6 +75,8 @@ case class SortMergeOuterJoin(
// For left and right outer joins, the output is ordered by the streamed input's join keys.
case LeftOuter => requiredOrders(leftKeys)
case RightOuter => requiredOrders(rightKeys)
// there are null rows in both streams, so there is no order
case FullOuter => Nil
case x => throw new IllegalArgumentException(
s"SortMergeOuterJoin should not take $x as the JoinType")
}
@ -165,6 +171,26 @@ case class SortMergeOuterJoin(
new RightOuterIterator(
smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
case FullOuter =>
val leftNullRow = new GenericInternalRow(left.output.length)
val rightNullRow = new GenericInternalRow(right.output.length)
val smjScanner = new SortMergeFullOuterJoinScanner(
leftKeyGenerator = createLeftKeyGenerator(),
rightKeyGenerator = createRightKeyGenerator(),
keyOrdering,
leftIter = RowIterator.fromScala(leftIter),
numLeftRows,
rightIter = RowIterator.fromScala(rightIter),
numRightRows,
boundCondition,
leftNullRow,
rightNullRow)
new FullOuterIterator(
smjScanner,
resultProj,
numOutputRows).toScala
case x =>
throw new IllegalArgumentException(
s"SortMergeOuterJoin should not take $x as the JoinType")
@ -271,3 +297,196 @@ private class RightOuterIterator(
override def getRow: InternalRow = resultProj(joinedRow)
}
/**
 * Merge-scans two iterators that are sorted by their join keys, producing the row pairs needed
 * for a full outer join. For each distinct non-null key, all rows from both sides sharing that
 * key are buffered into `leftMatches`/`rightMatches` and cross-checked against `boundCondition`;
 * rows that never satisfy the condition are emitted joined with the corresponding null row.
 * Rows whose key contains a null never match anything and are emitted with a null partner
 * immediately (see `advanceNext`).
 *
 * Callers drive the scan with `advanceNext()` and read results through the single mutable
 * `JoinedRow` returned by `getJoinedRow()`.
 *
 * @param leftKeyGenerator projection extracting the join key from a left-side row
 * @param rightKeyGenerator projection extracting the join key from a right-side row
 * @param keyOrdering ordering used to compare left and right join keys
 * @param leftIter sorted iterator over the left input
 * @param numLeftRows metric counting rows consumed from the left
 * @param rightIter sorted iterator over the right input
 * @param numRightRows metric counting rows consumed from the right
 * @param boundCondition extra (non-equi) join predicate evaluated on each candidate joined row
 * @param leftNullRow all-null row substituted when a right row has no left match
 * @param rightNullRow all-null row substituted when a left row has no right match
 */
private class SortMergeFullOuterJoinScanner(
leftKeyGenerator: Projection,
rightKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
leftIter: RowIterator,
numLeftRows: LongSQLMetric,
rightIter: RowIterator,
numRightRows: LongSQLMetric,
boundCondition: InternalRow => Boolean,
leftNullRow: InternalRow,
rightNullRow: InternalRow) {
// Single mutable output row; advanceNext() repoints its two halves on every call.
private[this] val joinedRow: JoinedRow = new JoinedRow()
// Current head row/key of each input; null once that iterator is exhausted.
private[this] var leftRow: InternalRow = _
private[this] var leftRowKey: InternalRow = _
private[this] var rightRow: InternalRow = _
private[this] var rightRowKey: InternalRow = _
// Cursors into the match buffers below, advanced by scanNextInBuffered().
private[this] var leftIndex: Int = 0
private[this] var rightIndex: Int = 0
// All rows (copied, since iterators reuse row objects) sharing the current matching key.
private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
// Bit i is set once buffered row i has matched at least one row from the other side;
// unmatched rows are later joined with the null row. Grown lazily in findMatchingRows().
private[this] var leftMatched: BitSet = new BitSet(1)
private[this] var rightMatched: BitSet = new BitSet(1)
// Prime both sides so leftRow/rightRow hold the first rows (or null if an input is empty).
advancedLeft()
advancedRight()
// --- Private methods --------------------------------------------------------------------------
/**
 * Advance the left iterator and compute the new row's join key.
 * @return true if the left iterator returned a row and false otherwise.
 */
private def advancedLeft(): Boolean = {
if (leftIter.advanceNext()) {
leftRow = leftIter.getRow
leftRowKey = leftKeyGenerator(leftRow)
numLeftRows += 1
true
} else {
// Exhausted: null out both row and key so callers can test either.
leftRow = null
leftRowKey = null
false
}
}
/**
 * Advance the right iterator and compute the new row's join key.
 * @return true if the right iterator returned a row and false otherwise.
 */
private def advancedRight(): Boolean = {
if (rightIter.advanceNext()) {
rightRow = rightIter.getRow
rightRowKey = rightKeyGenerator(rightRow)
numRightRows += 1
true
} else {
// Exhausted: null out both row and key so callers can test either.
rightRow = null
rightRowKey = null
false
}
}
/**
 * Populate the left and right buffers with rows matching the provided key.
 * This consumes rows from both iterators until their keys are different from the matching key.
 */
private def findMatchingRows(matchingKey: InternalRow): Unit = {
leftMatches.clear()
rightMatches.clear()
leftIndex = 0
rightIndex = 0
// Rows must be copied: the underlying iterators may reuse the same row object.
while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
leftMatches += leftRow.copy()
advancedLeft()
}
while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
rightMatches += rightRow.copy()
advancedRight()
}
// Reuse the existing BitSets when large enough (just clear them); otherwise reallocate.
if (leftMatches.size <= leftMatched.capacity) {
leftMatched.clear()
} else {
leftMatched = new BitSet(leftMatches.size)
}
if (rightMatches.size <= rightMatched.capacity) {
rightMatched.clear()
} else {
rightMatched = new BitSet(rightMatches.size)
}
}
/**
 * Scan the left and right buffers for the next valid match.
 *
 * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
 * If a left row has no valid matches on the right, or a right row has no valid matches on the
 * left, then the row is joined with the null row and the result is considered a valid match.
 *
 * @return true if a valid match is found, false otherwise.
 */
private def scanNextInBuffered(): Boolean = {
while (leftIndex < leftMatches.size) {
while (rightIndex < rightMatches.size) {
joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
if (boundCondition(joinedRow)) {
// Mark both rows as matched so neither gets a null-padded emission later.
leftMatched.set(leftIndex)
rightMatched.set(rightIndex)
rightIndex += 1
return true
}
rightIndex += 1
}
// Finished all right rows for this left row; rewind for the next left row.
rightIndex = 0
if (!leftMatched.get(leftIndex)) {
// the left row has never matched any right row, join it with null row
joinedRow(leftMatches(leftIndex), rightNullRow)
leftIndex += 1
return true
}
leftIndex += 1
}
// All left rows consumed; emit any right rows that never found a left partner.
while (rightIndex < rightMatches.size) {
if (!rightMatched.get(rightIndex)) {
// the right row has never matched any left row, join it with null row
joinedRow(leftNullRow, rightMatches(rightIndex))
rightIndex += 1
return true
}
rightIndex += 1
}
// There are no more valid matches in the left and right buffers
false
}
// --- Public methods --------------------------------------------------------------------------
// Returns the single mutable row that advanceNext() updates in place.
def getJoinedRow(): JoinedRow = joinedRow
/**
 * Advance to the next output row, mutating `joinedRow` to hold it.
 * @return true if another joined row was produced, false once both inputs are fully consumed.
 */
def advanceNext(): Boolean = {
// If we already buffered some matching rows, use them directly
// (the `<=` guard is conservative: when the buffers are exhausted or empty,
// scanNextInBuffered() simply returns false and we fall through).
if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
if (scanNextInBuffered()) {
return true
}
}
// A key containing any null can never equal another key, so such rows — and all
// remaining rows once the opposite side is exhausted — are emitted null-padded.
if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
joinedRow(leftRow.copy(), rightNullRow)
advancedLeft()
true
} else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
joinedRow(leftNullRow, rightRow.copy())
advancedRight()
true
} else if (leftRow != null && rightRow != null) {
// Both rows are present and neither have null values,
// so we populate the buffers with rows matching the next key
// (the smaller of the two current keys, to keep the merge in order).
val comp = keyOrdering.compare(leftRowKey, rightRowKey)
if (comp <= 0) {
findMatchingRows(leftRowKey.copy())
} else {
findMatchingRows(rightRowKey.copy())
}
// The buffers are freshly filled and non-empty, so this always finds a row.
scanNextInBuffered()
true
} else {
// Both iterators have been consumed
false
}
}
}
/**
 * Adapts a [[SortMergeFullOuterJoinScanner]] to the [[RowIterator]] interface: each
 * successful advance bumps the output-row metric, and `getRow` applies the result
 * projection to the scanner's (mutated-in-place) joined row.
 */
private class FullOuterIterator(
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
numRows: LongSQLMetric
) extends RowIterator {
// The scanner repoints this single JoinedRow on every advance; we never copy it.
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()

override def advanceNext(): Boolean = {
  if (smjScanner.advanceNext()) {
    numRows += 1
    true
  } else {
    false
  }
}

override def getRow: InternalRow = resultProj(joinedRow)
}

View file

@ -83,7 +83,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[SortMergeOuterJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a",
classOf[ShuffledHashOuterJoin]),
classOf[SortMergeOuterJoin]),
("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",

View file

@ -76,37 +76,37 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using ShuffledHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
test(s"$testName using SortMergeOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
test(s"$testName using SortMergeOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}