[SPARK-1495][SQL]add support for left semi join

Just submitting another solution for #395.

Author: Daoyuan <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
Authored by Daoyuan on 2014-06-09 11:31:36 -07:00; committed by Michael Armbrust
parent 35630c86ff
commit 0cf6002801
37 changed files with 216 additions and 3 deletions


@@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val OUTER = Keyword("OUTER")
protected val RIGHT = Keyword("RIGHT")
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
protected val TRUE = Keyword("TRUE")
@@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected lazy val joinType: Parser[JoinType] =
INNER ^^^ Inner |
LEFT ~ SEMI ^^^ LeftSemi |
LEFT ~ opt(OUTER) ^^^ LeftOuter |
RIGHT ~ opt(OUTER) ^^^ RightOuter |
FULL ~ opt(OUTER) ^^^ FullOuter
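Note that LEFT ~ SEMI is listed before LEFT ~ opt(OUTER): opt(OUTER) also succeeds on a bare LEFT, and standard `|` alternation does not backtrack into a later alternative once an earlier one has matched, so the reverse order would leave SEMI unconsumed and fail the parse. The syntax this enables, sketched with hypothetical table names:

    SELECT * FROM orders o LEFT SEMI JOIN customers c ON o.customerId = c.customerId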


@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
// All predicates can be evaluated for a left semi join (predicates in the WHERE
// clause can only come from the left table, so they can all be pushed down.)
case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering hash join on: $condition")
splitPredicates(condition.toSeq, join)
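Because the output of a left semi join contains only the left side's columns, any predicate in an enclosing WHERE clause can reference only left-table attributes, which is why all of them can safely be handed to splitPredicates along with the join condition. An illustration with stand-in tables:

    SELECT * FROM x LEFT SEMI JOIN y ON x.a = y.a WHERE x.b > 5
    -- x.b > 5 touches only the left side, so it can be pushed down with the join keys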


@@ -22,3 +22,4 @@ case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
case object LeftSemi extends JoinType


@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
import org.apache.spark.sql.catalyst.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,7 +81,12 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {
def references = condition.map(_.references).getOrElse(Set.empty)
def output = left.output ++ right.output
def output = joinType match {
case LeftSemi =>
left.output
case _ =>
left.output ++ right.output
}
}
case class InsertIntoTable(
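Unlike the other join types, LeftSemi acts as a filter on its left child: right-side columns never appear in the result. A minimal sketch of the consequence, using the case class above (leftPlan and rightPlan are hypothetical children):

    val semi = Join(leftPlan, rightPlan, LeftSemi, condition = None)
    assert(semi.output == leftPlan.output)  // the right child's attributes never surface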


@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TakeOrdered ::
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::


@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching hash
// keys using the HashFilteredJoin pattern.
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = execution.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sparkContext) :: Nil
case _ => Nil
}
}
object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
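Which physical operator the new strategy picks depends on whether the HashFilteredJoin pattern can extract equi-join keys from the condition. Two examples (table names illustrative):

    sql("SELECT * FROM x LEFT SEMI JOIN y ON x.a = y.a")       // equi-join keys found: LeftSemiJoinHash
    sql("SELECT * FROM x LEFT SEMI JOIN y ON x.a >= y.a + 2")  // no hash keys: LeftSemiJoinBNL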


@@ -140,6 +140,137 @@ case class HashJoin(
}
}
/**
* :: DeveloperApi ::
* Builds the right table's join keys into a HashSet, then streams through the left
* table, keeping each row whose join keys appear in the set.
*/
@DeveloperApi
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
override def outputPartitioning: Partitioning = left.outputPartitioning
override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
val (buildPlan, streamedPlan) = (right, left)
val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
def output = left.output
@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)
def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashTable = new java.util.HashSet[Row]()
var currentRow: Row = null
// Build a hash set of the build-side join keys (rows with null keys are skipped)
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if (!rowKey.anyNull) {
// HashSet.add is a no-op when the key is already present.
hashTable.add(rowKey)
}
}
new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatched: Boolean = false
private[this] val joinKeys = streamSideKeyGenerator()
// hasNext is idempotent: a matched row is cached until next() consumes it.
override final def hasNext: Boolean =
currentHashMatched || (streamIter.hasNext && fetchNext())
override final def next() = {
currentHashMatched = false
currentStreamedRow
}
/**
* Searches the streamed iterator for the next row with at least one match in the hash set.
*
* @return true if the search succeeds, and false if the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatched = false
while (!currentHashMatched && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatched = hashTable.contains(joinKeys.currentValue)
}
}
currentHashMatched
}
}
}
}
}
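LeftSemiJoinHash is essentially a hash-set semi join specialized to Rows and Catalyst projections. A minimal sketch of the same idea over plain Scala collections (names and signature are hypothetical; the real operator additionally skips null keys, as the anyNull checks show):

    // Keep each left element whose key appears among the right-side keys;
    // each left row is emitted at most once, regardless of match count.
    def semiJoinHash[K, L, R](left: Seq[L], right: Seq[R])
                             (leftKey: L => K, rightKey: R => K): Seq[L] = {
      val keys = right.map(rightKey).toSet
      left.filter(l => keys.contains(leftKey(l)))
    }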
/**
* :: DeveloperApi ::
* Uses a broadcast nested loop join to compute the left semi join result when there are
* no join keys for a hash join.
*/
@DeveloperApi
case class LeftSemiJoinBNL(
streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
(@transient sc: SparkContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.
override def outputPartitioning: Partitioning = streamed.outputPartitioning
override def otherCopyArgs = sc :: Nil
def output = left.output
/** The streamed relation. */
def left = streamed
/** The broadcast relation. */
def right = broadcast
@transient lazy val boundCondition =
InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
streamedIter.filter(streamedRow => {
var i = 0
var matched = false
while (i < broadcastedRelation.value.size && !matched) {
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matched = true
}
i += 1
}
matched
})
}
}
}
/**
* :: DeveloperApi ::
*/
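LeftSemiJoinBNL covers the general case where the condition is an arbitrary predicate rather than key equality. A collection-level sketch under the same caveats (hypothetical names):

    // Keep each left element matching at least one right element;
    // exists stops at the first match, like the while loop above.
    def semiJoinNestedLoop[L, R](left: Seq[L], right: IndexedSeq[R])
                                (p: (L, R) => Boolean): Seq[L] =
      left.filter(l => right.exists(r => p(l, r)))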


@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
fail(
s"""
|Exception thrown while executing query:
|${rdd.logicalPlan}
|${rdd.queryExecution}
|== Exception ==
|$e
""".stripMargin)


@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
}
test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq((3,1), (3,2))
)
}
test("index into array of arrays") {
checkAnswer(
sql(


@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
DataSinks,
Scripts,
PartialAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
CartesianProduct,


@@ -685,6 +685,7 @@ private[hive] object HiveQl {
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
}
assert(other.size <= 1, "Unhandled join clauses.")
Join(nodeToRelation(relation1),
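The same syntax now flows through the Hive parser as well; a HiveQL example (tables are stand-ins):

    SELECT t1.key FROM t1 LEFT SEMI JOIN t2 ON t1.key = t2.key

Hive itself permits only left-table columns in the SELECT list of a semi join, which lines up with the LeftSemi output rule added to the logical Join operator above.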


@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2


@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2


@@ -0,0 +1,20 @@
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


@@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"lateral_view",
"lateral_view_cp",
"lateral_view_ppd",
"leftsemijoin",
"leftsemijoin_mr",
"lineage1",
"literal_double",
"literal_ints",