[SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins

Use HashedRelation, which is a more optimized data structure, and reduce code complexity

Author: Xiu Guo <xguo27@gmail.com>

Closes #11291 from xguo27/SPARK-13422.
Authored by Xiu Guo on 2016-02-22 16:34:02 -08:00; committed by Reynold Xin
parent 173aa949c3
commit 2063781840
3 changed files with 14 additions and 81 deletions
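
To make the rationale concrete for readers outside the Spark codebase, here is a minimal, self-contained Scala sketch of a left semi join (illustrative only; these names and types are not Spark's API). It shows why a key-to-rows structure such as HashedRelation subsumes a key-only HashSet: the map form can also evaluate an optional extra join condition against the matched rows, so one data structure serves both code paths that the old code kept separate.

object SemiJoinSketch {
  type Row = (Int, String) // hypothetical (joinKey, payload) row

  // Left semi join: keep each left row whose key has at least one match on
  // the right that also satisfies the optional extra condition.
  def leftSemiJoin(
      left: Seq[Row],
      right: Seq[Row],
      condition: Option[(Row, Row) => Boolean]): Seq[Row] = {
    // Build phase: group the right side by join key (the role HashedRelation plays).
    val built: Map[Int, Seq[Row]] = right.groupBy(_._1)
    // Probe phase: a key-only HashSet could answer the None case, but only a
    // key -> rows map can also answer the Some(condition) case.
    left.filter { l =>
      built.get(l._1).exists(rows => condition.forall(c => rows.exists(r => c(l, r))))
    }
  }

  def main(args: Array[String]): Unit = {
    val left = Seq((1, "a"), (2, "b"), (3, "c"))
    val right = Seq((1, "x"), (1, "y"), (3, "z"))
    println(leftSemiJoin(left, right, None))                        // List((1,a), (3,c))
    println(leftSemiJoin(left, right, Some((_, r) => r._2 == "y"))) // List((1,a))
  }
}

With only a HashSet of keys, the Some(condition) case above would be unanswerable, which is exactly why the old code needed two separate implementations.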

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala

@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))

   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = if (condition.isEmpty) {
-      HashSetBroadcastMode(rightKeys, right.output)
-    } else {
-      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    }
+    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
     UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
   }

   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")

-    if (condition.isEmpty) {
-      val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
-      }
-    } else {
-      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        val hashedRelation = broadcastedRelation.value
-        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-        hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-      }
+    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+    left.execute().mapPartitionsInternal { streamIter =>
+      val hashedRelation = broadcastedRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
     }
   }
 }
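
The execute path now has a single shape: build the relation once, broadcast it, and let every stream-side partition probe the same shared structure. A hedged, non-Spark Scala sketch of that shape, with all names invented for illustration:

object BroadcastSemiJoinSketch {
  type Row = (Int, String) // hypothetical (joinKey, payload) row

  // Stand-in for right.executeBroadcast[HashedRelation](): build once, share everywhere.
  def buildRelation(right: Seq[Row]): Map[Int, Seq[Row]] = right.groupBy(_._1)

  // Stand-in for the mapPartitionsInternal body: each partition probes the
  // same broadcast relation instead of rebuilding its own copy.
  def probePartition(streamIter: Iterator[Row], relation: Map[Int, Seq[Row]]): Iterator[Row] =
    streamIter.filter(row => relation.contains(row._1))

  def main(args: Array[String]): Unit = {
    val relation = buildRelation(Seq((1, "x"), (3, "z")))
    val leftPartitions = Seq(Iterator((1, "a"), (2, "b")), Iterator((3, "c")))
    leftPartitions.foreach(p => probePartition(p, relation).foreach(println)) // (1,a) then (3,c)
  }
}

Note that the incPeakExecutionMemory bump, which the old code only performed on the condition-bearing branch, now happens on every run.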

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala

@@ -43,24 +43,6 @@ trait HashSemiJoin
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)

-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
-  }
-
-  protected def hashSemiJoin(
-      streamIter: Iterator[InternalRow],
-      hashSet: java.util.Set[InternalRow],
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    streamIter.filter(current => {
-      val key = joinKeys(current)
-      val r = !key.anyNull && hashSet.contains(key)
-      if (r) numOutputRows += 1
-      r
-    })
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@
     streamIter.filter { current =>
       val key = joinKeys(current)
       lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
-      }
+      })
       if (r) numOutputRows += 1
       r
     }
   }
 }
-
-private[execution] object HashSemiJoin {
-  def buildKeyHashSet(
-      keys: Seq[Expression],
-      attributes: Seq[Attribute],
-      rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-
-    // Create a Hash set of buildKeys
-    val key = UnsafeProjection.create(keys, attributes)
-    while (rows.hasNext) {
-      val currentRow = rows.next()
-      val rowKey = key(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-    hashSet
-  }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
-
-  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
-  }
-}
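
The new predicate short-circuits: when there is no extra condition, a non-null key found in the relation is enough and the matched rows are never scanned (and since rowBuffer is a lazy val, a null key skips the lookup entirely). A hedged standalone sketch of the same logic, modelling a nullable join key as Option[Int] (all names hypothetical, not Spark's API):

object SemiJoinPredicateSketch {
  type Key = Option[Int] // None plays the role of a NULL join key, which never matches

  def semiJoinFilter(
      streamKeys: Iterator[Key],
      relation: Map[Int, Seq[String]],
      condition: Option[String => Boolean]): Iterator[Key] =
    streamKeys.filter { key =>
      key.exists { k =>                          // !key.anyNull
        relation.get(k) match {
          case None       => false               // rowBuffer == null: key absent
          case Some(rows) =>                     // condition.isEmpty || rowBuffer.exists(...)
            condition.forall(c => rows.exists(c))
        }
      }
    }

  def main(args: Array[String]): Unit = {
    val relation = Map(1 -> Seq("x", "y"), 2 -> Seq("z"))
    val keys = Iterator(Some(1), Some(2), Some(3), None)
    println(semiJoinFilter(keys, relation, Some(_ == "y")).toList) // List(Some(1))
  }
}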

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala

@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")

     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter)
-        hashSemiJoin(streamIter, hashSet, numOutputRows)
-      } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
-        hashSemiJoin(streamIter, hashRelation, numOutputRows)
-      }
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
 }
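
Unlike the broadcast variant, this operator runs after both children have been shuffled by join key, so each task builds a relation from just its right-side partition and probes it with the co-partitioned left-side partition. A hedged sketch of that per-partition pattern (plain Scala, illustrative names only):

object ShuffledSemiJoinSketch {
  type Row = (Int, String) // hypothetical (joinKey, payload) row

  // Stand-in for the zipPartitions body above: build from the right (build
  // side) partition, then filter the left (stream side) partition against it.
  def joinPartition(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
    val relation = buildIter.toSeq.groupBy(_._1) // like HashedRelation(buildIter, rightKeyGenerator)
    streamIter.filter(row => relation.contains(row._1))
  }

  def main(args: Array[String]): Unit = {
    // Two co-partitioned "partitions" (hash-partitioned by key % 2).
    val rightParts = Seq(Iterator((2, "x")), Iterator((1, "y"), (3, "z")))
    val leftParts  = Seq(Iterator((2, "b"), (4, "d")), Iterator((1, "a")))
    rightParts.zip(leftParts).foreach { case (b, s) =>
      joinPartition(b, s).foreach(println)       // prints (2,b) then (1,a)
    }
  }
}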