[SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins
Use the HashedRelation which is a more optimized datastructure and reduce code complexity Author: Xiu Guo <xguo27@gmail.com> Closes #11291 from xguo27/SPARK-13422.
This commit is contained in:
parent
173aa949c3
commit
2063781840
|
@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
|
|||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
/**
|
||||
* Build the right table's join keys into a HashSet, and iteratively go through the left
|
||||
* table, to find the if join keys are in the Hash set.
|
||||
* Build the right table's join keys into a HashedRelation, and iteratively go through the left
|
||||
* table, to find if the join keys are in the HashedRelation.
|
||||
*/
|
||||
case class BroadcastLeftSemiJoinHash(
|
||||
leftKeys: Seq[Expression],
|
||||
|
@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
|
|||
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
|
||||
|
||||
override def requiredChildDistribution: Seq[Distribution] = {
|
||||
val mode = if (condition.isEmpty) {
|
||||
HashSetBroadcastMode(rightKeys, right.output)
|
||||
} else {
|
||||
HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
|
||||
}
|
||||
val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
|
||||
UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
|
||||
}
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val numOutputRows = longMetric("numOutputRows")
|
||||
|
||||
if (condition.isEmpty) {
|
||||
val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
|
||||
left.execute().mapPartitionsInternal { streamIter =>
|
||||
hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
|
||||
}
|
||||
} else {
|
||||
val broadcastedRelation = right.executeBroadcast[HashedRelation]()
|
||||
left.execute().mapPartitionsInternal { streamIter =>
|
||||
val hashedRelation = broadcastedRelation.value
|
||||
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
|
||||
hashSemiJoin(streamIter, hashedRelation, numOutputRows)
|
||||
}
|
||||
val broadcastedRelation = right.executeBroadcast[HashedRelation]()
|
||||
left.execute().mapPartitionsInternal { streamIter =>
|
||||
val hashedRelation = broadcastedRelation.value
|
||||
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
|
||||
hashSemiJoin(streamIter, hashedRelation, numOutputRows)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,24 +43,6 @@ trait HashSemiJoin {
|
|||
@transient private lazy val boundCondition =
|
||||
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
|
||||
|
||||
protected def buildKeyHashSet(
|
||||
buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
|
||||
HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
|
||||
}
|
||||
|
||||
protected def hashSemiJoin(
|
||||
streamIter: Iterator[InternalRow],
|
||||
hashSet: java.util.Set[InternalRow],
|
||||
numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
|
||||
val joinKeys = leftKeyGenerator
|
||||
streamIter.filter(current => {
|
||||
val key = joinKeys(current)
|
||||
val r = !key.anyNull && hashSet.contains(key)
|
||||
if (r) numOutputRows += 1
|
||||
r
|
||||
})
|
||||
}
|
||||
|
||||
protected def hashSemiJoin(
|
||||
streamIter: Iterator[InternalRow],
|
||||
hashedRelation: HashedRelation,
|
||||
|
@ -70,44 +52,11 @@ trait HashSemiJoin {
|
|||
streamIter.filter { current =>
|
||||
val key = joinKeys(current)
|
||||
lazy val rowBuffer = hashedRelation.get(key)
|
||||
val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
|
||||
val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
|
||||
(row: InternalRow) => boundCondition(joinedRow(current, row))
|
||||
}
|
||||
})
|
||||
if (r) numOutputRows += 1
|
||||
r
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[execution] object HashSemiJoin {
|
||||
def buildKeyHashSet(
|
||||
keys: Seq[Expression],
|
||||
attributes: Seq[Attribute],
|
||||
rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
|
||||
val hashSet = new java.util.HashSet[InternalRow]()
|
||||
|
||||
// Create a Hash set of buildKeys
|
||||
val key = UnsafeProjection.create(keys, attributes)
|
||||
while (rows.hasNext) {
|
||||
val currentRow = rows.next()
|
||||
val rowKey = key(currentRow)
|
||||
if (!rowKey.anyNull) {
|
||||
val keyExists = hashSet.contains(rowKey)
|
||||
if (!keyExists) {
|
||||
hashSet.add(rowKey.copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
hashSet
|
||||
}
|
||||
}
|
||||
|
||||
/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
|
||||
private[execution] case class HashSetBroadcastMode(
|
||||
keys: Seq[Expression],
|
||||
attributes: Seq[Attribute]) extends BroadcastMode {
|
||||
|
||||
override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
|
||||
HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
|
|||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
/**
|
||||
* Build the right table's join keys into a HashSet, and iteratively go through the left
|
||||
* table, to find the if join keys are in the Hash set.
|
||||
* Build the right table's join keys into a HashedRelation, and iteratively go through the left
|
||||
* table, to find if the join keys are in the HashedRelation.
|
||||
*/
|
||||
case class LeftSemiJoinHash(
|
||||
leftKeys: Seq[Expression],
|
||||
|
@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
|
|||
val numOutputRows = longMetric("numOutputRows")
|
||||
|
||||
right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
|
||||
if (condition.isEmpty) {
|
||||
val hashSet = buildKeyHashSet(buildIter)
|
||||
hashSemiJoin(streamIter, hashSet, numOutputRows)
|
||||
} else {
|
||||
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
|
||||
hashSemiJoin(streamIter, hashRelation, numOutputRows)
|
||||
}
|
||||
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
|
||||
hashSemiJoin(streamIter, hashRelation, numOutputRows)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue