[SQL] Rewrite join implementation to allow streaming of one relation.

Previously we materialized everything in memory. This change also uses the projection interface, so it will be easier to plug in code generation (it is ported from that branch). @rxin @liancheng

Author: Michael Armbrust <michael@databricks.com>

Closes #250 from marmbrus/hashJoin and squashes the following commits:

1ad873e [Michael Armbrust] Change hasNext logic back to the correct version.
8e6f2a2 [Michael Armbrust] Review comments.
1e9fb63 [Michael Armbrust] style
bc0cb84 [Michael Armbrust] Rewrite join implementation to allow streaming of one relation.
commit 5731af5be6
parent 841721e03c
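The pattern the commit message describes, build a hash table over one relation and stream the other past it, can be seen in miniature in the standalone Scala sketch below. It is illustrative only: the Map-based row model, the single joinKey, and all names are assumptions for this sketch, not part of the patch.

import scala.collection.mutable

object StreamingHashJoinSketch {
  // Joins two relations (modeled as iterators of Map-based rows) on a single key column.
  // Only the build side is materialized; the streamed side is consumed lazily.
  def hashJoin(
      buildSide: Iterator[Map[String, Any]],
      streamSide: Iterator[Map[String, Any]],
      joinKey: String): Iterator[(Map[String, Any], Map[String, Any])] = {
    // Build phase: hash every build-side row by its join key, skipping NULL keys
    // so the equi-join keeps SQL's three-valued logic.
    val hashTable = mutable.HashMap.empty[Any, mutable.ArrayBuffer[Map[String, Any]]]
    buildSide.foreach { row =>
      val key = row(joinKey)
      if (key != null) {
        hashTable.getOrElseUpdate(key, mutable.ArrayBuffer.empty) += row
      }
    }
    // Probe phase: each streamed row looks up its matches; no streamed row is buffered.
    streamSide.flatMap { streamedRow =>
      hashTable.getOrElse(streamedRow(joinKey), Nil).map(buildRow => (streamedRow, buildRow))
    }
  }

  def main(args: Array[String]): Unit = {
    val users = Iterator(Map("id" -> 1, "name" -> "ada"), Map("id" -> 2, "name" -> "bob"))
    val orders = Iterator(Map("id" -> 1, "item" -> "book"), Map("id" -> 2, "item" -> "pen"),
      Map("id" -> 1, "item" -> "lamp"))
    hashJoin(buildSide = users, streamSide = orders, joinKey = "id").foreach(println)
  }
}

Only the build side is ever held in memory; the streamed side stays a lazy iterator, which is the same split the new HashJoin operator makes with its buildPlan/streamedPlan pair in the diff below.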
@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
     s"[${this.mkString(",")}]"
 
+  def copy(): Row
+
+  /** Returns true if there are any NULL values in this row. */
+  def anyNull: Boolean = {
+    var i = 0
+    while (i < length) {
+      if (isNullAt(i)) { return true }
+      i += 1
+    }
+    false
+  }
 }
 
 /**
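The two members added to Row above are what the rewritten join relies on: anyNull lets both sides skip keys containing NULLs, and copy() lets the build side snapshot rows before buffering them in the hash table, since operators may reuse mutable row objects. A hypothetical array-backed row satisfying that contract might look like the following sketch; it is not Spark's implementation.

// Hypothetical array-backed row, for illustration only; not Spark's implementation.
class SimpleRow(values: Array[Any]) extends Seq[Any] {
  def apply(i: Int): Any = values(i)
  def length: Int = values.length
  def iterator: Iterator[Any] = values.iterator
  def isNullAt(i: Int): Boolean = values(i) == null

  // Defensive snapshot, needed when a row backed by a reused buffer is stored
  // in a long-lived structure such as the join's build-side hash table.
  def copy(): SimpleRow = new SimpleRow(values.clone())

  // Same loop as Row.anyNull above: true if any column holds NULL.
  def anyNull: Boolean = {
    var i = 0
    while (i < length) {
      if (isNullAt(i)) return true
      i += 1
    }
    false
  }
}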
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
 
+object InterpretedPredicate {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
 trait Predicate extends Expression {
   self: Product =>
@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TopK ::
       PartialAggregation ::
-      SparkEquiInnerJoin ::
+      HashJoin ::
       ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
-  object SparkEquiInnerJoin extends Strategy {
+  object HashJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
         logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val leftKeys = joinKeys.map(_._1)
         val rightKeys = joinKeys.map(_._2)
 
-        val joinOp = execution.SparkEquiInnerJoin(
-          leftKeys, rightKeys, planLater(left), planLater(right))
+        val joinOp = execution.HashJoin(
+          leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
 
         // Make sure other conditions are met if present.
         if (otherPredicates.nonEmpty) {
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
@@ -40,32 +41,92 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      // TODO: Use Spark's HashMap implementation.
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
+
+      // Create a mapping of buildKeys -> rows
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatches: ArrayBuffer[Row] = _
+        private[this] var currentMatchPosition: Int = -1
+
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+            (streamIter.hasNext && fetchNext())
+
+        override final def next() = {
+          val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+          currentMatchPosition += 1
+          ret
+        }
 
-  /**
-   * Filters any rows where the any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in hashtable.
+         *
+         * @return true if the search is successful, and false the streamed iterator runs out of
+         *         tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatches = null
+          currentMatchPosition = -1
+
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatches = hashTable.get(joinKeys.currentValue)
+            }
+          }
+
+          if (currentHashMatches == null) {
+            false
+          } else {
+            currentMatchPosition = 0
+            true
+          }
+        }
+      }
     }
+  }
 }
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
+    InterpretedPredicate(
       condition
         .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-        .getOrElse(Literal(true))
+        .getOrElse(Literal(true)))
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      // TODO: Use Spark's BitSet.
+      val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
            matchedRows += buildRow(streamedRow ++ broadcastedRow)
            matched = true
            includedBroadcastTuples += i
@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
-      SparkEquiInnerJoin,
+      HashJoin,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin