[SPARK-1495][SQL] Add support for left semi join

Just submitting another solution for #395.

Author: Daoyuan <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
Authored by Daoyuan on 2014-06-09 11:31:36 -07:00; committed by Michael Armbrust
parent 35630c86ff
commit 0cf6002801
37 changed files with 216 additions and 3 deletions

@@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val OUTER = Keyword("OUTER")
   protected val RIGHT = Keyword("RIGHT")
   protected val SELECT = Keyword("SELECT")
+  protected val SEMI = Keyword("SEMI")
   protected val STRING = Keyword("STRING")
   protected val SUM = Keyword("SUM")
   protected val TRUE = Keyword("TRUE")
@@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected lazy val joinType: Parser[JoinType] =
     INNER ^^^ Inner |
+    LEFT ~ SEMI ^^^ LeftSemi |
     LEFT ~ opt(OUTER) ^^^ LeftOuter |
     RIGHT ~ opt(OUTER) ^^^ RightOuter |
     FULL ~ opt(OUTER) ^^^ FullOuter
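With `LEFT ~ SEMI` listed before `LEFT ~ opt(OUTER)`, the parser tries the semi join alternative first instead of committing to a left outer join as soon as it sees `LEFT`. A minimal usage sketch (the table and column names here are hypothetical, not from this patch):

```scala
// A left semi join returns each matching row from the left table at most once,
// and only the left table's columns appear in the result.
val matched = sqlContext.sql(
  "SELECT * FROM people LEFT SEMI JOIN orders ON people.id = orders.personId")
```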

@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
     case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
       logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
       splitPredicates(predicates ++ condition, join)
+    // All predicates can be evaluated for left semi join (those in the WHERE
+    // clause can only reference the left table, so they can all be pushed down).
+    case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
+      logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
+      splitPredicates(predicates ++ condition, join)
     case join @ Join(left, right, joinType, condition) =>
       logger.debug(s"Considering hash join on: $condition")
       splitPredicates(condition.toSeq, join)
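This is safe because the semi join's output contains only the left side's columns: in a query like `SELECT * FROM l LEFT SEMI JOIN r ON l.k = r.k WHERE l.v > 0` (hypothetical tables), the `WHERE` predicate can only reference `l`, so the filter predicates and the join condition can be combined and handed to `splitPredicates` together, just as in the inner join case.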

@@ -22,3 +22,4 @@ case object Inner extends JoinType
 case object LeftOuter extends JoinType
 case object RightOuter extends JoinType
 case object FullOuter extends JoinType
+case object LeftSemi extends JoinType

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
 import org.apache.spark.sql.catalyst.types._

 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,7 +81,12 @@ case class Join(
     condition: Option[Expression]) extends BinaryNode {

   def references = condition.map(_.references).getOrElse(Set.empty)
-  def output = left.output ++ right.output
+  def output = joinType match {
+    case LeftSemi =>
+      left.output
+    case _ =>
+      left.output ++ right.output
+  }
 }

 case class InsertIntoTable(
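The schema change here is the defining property of a semi join: only the left side's attributes survive in the output. A hedged illustration (hypothetical tables):

```scala
// Assuming x has columns (a, b) and y has columns (a, c), the analyzed output
// of the semi join is x's schema alone; y's columns never appear in a result row.
val result = sql("SELECT * FROM x LEFT SEMI JOIN y ON x.a = y.a")
// result rows are (a, b) tuples from x that have at least one match in y.
```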

@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TakeOrdered ::
       PartialAggregation ::
+      LeftSemiJoin ::
       HashJoin ::
       ParquetOperations ::
       BasicOperators ::
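Note the ordering: `LeftSemiJoin` is registered ahead of `HashJoin`, so a logical `Join` with `LeftSemi` type is claimed by the dedicated strategy before the generic hash join strategy sees it.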

@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>

+  object LeftSemiJoin extends Strategy with PredicateHelper {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      // Find left semi joins where at least some predicates can be evaluated by matching
+      // hash keys using the HashFilteredJoin pattern.
+      case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+        val semiJoin = execution.LeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+      // no predicate can be evaluated by matching hash keys
+      case logical.Join(left, right, LeftSemi, condition) =>
+        execution.LeftSemiJoinBNL(
+          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+      case _ => Nil
+    }
+  }
+
   object HashJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Find inner joins where at least some predicates can be evaluated by matching hash keys
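The dispatch logic: when `HashFilteredJoin` can extract equi-join keys from the predicates, the planner emits a hash-based semi join (wrapping any leftover predicates in a `Filter`); otherwise it falls back to a broadcast nested loop. A hedged sketch of which physical operator each query shape would get (table names hypothetical):

```scala
// Equi-join key available: planned as execution.LeftSemiJoinHash.
sql("SELECT * FROM a LEFT SEMI JOIN b ON a.key = b.key")
// Pure inequality, no hashable keys: planned as execution.LeftSemiJoinBNL.
sql("SELECT * FROM a LEFT SEMI JOIN b ON a.key > b.key")
```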

@@ -140,6 +140,137 @@ case class HashJoin(
   }
 }
+
+/**
+ * :: DeveloperApi ::
+ * Builds a HashSet of the right table's join keys, then iterates over the left table,
+ * keeping the rows whose join keys are found in the set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  val (buildPlan, streamedPlan) = (right, left)
+  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+
+  def output = left.output
+
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
+
+  def execute() = {
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashTable = new java.util.HashSet[Row]()
+      var currentRow: Row = null
+
+      // Create a hash set of the build-side keys, skipping keys that contain nulls.
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if (!rowKey.anyNull) {
+          val keyExists = hashTable.contains(rowKey)
+          if (!keyExists) {
+            hashTable.add(rowKey)
+          }
+        }
+      }
+
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatched: Boolean = false
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          streamIter.hasNext && fetchNext()
+
+        override final def next() = {
+          currentStreamedRow
+        }
+
+        /**
+         * Searches the streamed iterator for the next row whose join keys appear in the
+         * hash set.
+         *
+         * @return true if the search is successful, and false if the streamed iterator
+         *         runs out of tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatched = false
+          while (!currentHashMatched && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatched = hashTable.contains(joinKeys.currentValue)
+            }
+          }
+          currentHashMatched
+        }
+      }
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no
+ * join keys usable for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+    (@transient sc: SparkContext)
+  extends BinaryNode {
+  // TODO: Override requiredChildDistribution.
+
+  override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+  override def otherCopyArgs = sc :: Nil
+
+  def output = left.output
+
+  /** The Streamed Relation */
+  def left = streamed
+  /** The Broadcast relation */
+  def right = broadcast
+
+  @transient lazy val boundCondition =
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
+
+  def execute() = {
+    val broadcastedRelation =
+      sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+    streamed.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow
+
+      streamedIter.filter(streamedRow => {
+        var i = 0
+        var matched = false
+
+        while (i < broadcastedRelation.value.size && !matched) {
+          val broadcastedRow = broadcastedRelation.value(i)
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+            matched = true
+          }
+          i += 1
+        }
+        matched
+      })
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  */
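For intuition, here is a minimal, self-contained sketch of the hash-set semi join on plain Scala collections (the tuple-based key and row types are assumptions for illustration, not Spark's `Row` API):

```scala
// Semi join via hash set: build the distinct right-side keys once, then keep
// each left row whose key is present -- each left row is emitted at most once.
def leftSemiJoinHash[K, V](left: Seq[(K, V)], right: Seq[(K, _)]): Seq[(K, V)] = {
  val rightKeys = right.map(_._1).toSet          // build side
  left.filter { case (k, _) => rightKeys(k) }    // streamed side
}
```

The nested-loop fallback amounts to an `exists` over the broadcast side; like the `while (i < ... && !matched)` loop above, `exists` stops scanning at the first match:

```scala
// Semi join via nested loop: keep a left row iff some right row satisfies the
// arbitrary (possibly non-equi) join condition.
def leftSemiJoinNL[A, B](left: Seq[A], right: Seq[B])(cond: (A, B) => Boolean): Seq[A] =
  left.filter(a => right.exists(b => cond(a, b)))
```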

@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
       fail(
         s"""
           |Exception thrown while executing query:
-          |${rdd.logicalPlan}
+          |${rdd.queryExecution}
           |== Exception ==
           |$e
         """.stripMargin)

@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
       arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
   }

+  test("left semi greater than predicate") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+      Seq((3,1), (3,2))
+    )
+  }
+
   test("index into array of arrays") {
     checkAnswer(
       sql(
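Since `x.a >= y.a + 2` has no equi-join component, this query cannot be planned by `LeftSemiJoinHash` and exercises the `LeftSemiJoinBNL` path instead (the "left semi that can't be done with a hash join" test from the commit log). Assuming `testData2` holds the pairs (1,1), (1,2), (2,1), (2,2), (3,1), (3,2), only the x rows with a = 3 satisfy the predicate against some y row (3 >= 1 + 2), giving the expected (3,1) and (3,2).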

@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
+      LeftSemiJoin,
      HashJoin,
      BasicOperators,
      CartesianProduct,

@@ -685,6 +685,7 @@ private[hive] object HiveQl {
           case "TOK_RIGHTOUTERJOIN" => RightOuter
           case "TOK_LEFTOUTERJOIN" => LeftOuter
           case "TOK_FULLOUTERJOIN" => FullOuter
+          case "TOK_LEFTSEMIJOIN" => LeftSemi
         }
         assert(other.size <= 1, "Unhandled join clauses.")
         Join(nodeToRelation(relation1),

@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2

@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2

@@ -0,0 +1,20 @@
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1

@@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "lateral_view",
     "lateral_view_cp",
     "lateral_view_ppd",
+    "leftsemijoin",
+    "leftsemijoin_mr",
     "lineage1",
     "literal_double",
     "literal_ints",