[SPARK-1495][SQL]add support for left semi join
Just submit another solution for #395 Author: Daoyuan <daoyuan.wang@intel.com> Author: Michael Armbrust <michael@databricks.com> Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits: d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837 6713c09 [Michael Armbrust] Better debugging for failed query tests. 035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join. 5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser. 4c726e5 [Daoyuan] improvement according to Michael 8d4a121 [Daoyuan] add golden files for leftsemijoin 83a3c8a [Daoyuan] scala style fix 14cff80 [Daoyuan] add support for left semi join
This commit is contained in:
parent
35630c86ff
commit
0cf6002801
|
@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
|
|||
protected val OUTER = Keyword("OUTER")
|
||||
protected val RIGHT = Keyword("RIGHT")
|
||||
protected val SELECT = Keyword("SELECT")
|
||||
protected val SEMI = Keyword("SEMI")
|
||||
protected val STRING = Keyword("STRING")
|
||||
protected val SUM = Keyword("SUM")
|
||||
protected val TRUE = Keyword("TRUE")
|
||||
|
@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
|
|||
|
||||
// Parses the join-type keyword(s) of a JOIN clause into the matching JoinType.
// The alternatives are tried in the order written, so `LEFT ~ SEMI` is listed before
// `LEFT ~ opt(OUTER)`: a bare LEFT already satisfies the latter (OUTER is optional), which
// would otherwise claim the LEFT of "LEFT SEMI" as a LeftOuter join.
protected lazy val joinType: Parser[JoinType] =
|
||||
INNER ^^^ Inner |
|
||||
LEFT ~ SEMI ^^^ LeftSemi |
|
||||
LEFT ~ opt(OUTER) ^^^ LeftOuter |
|
||||
RIGHT ~ opt(OUTER) ^^^ RightOuter |
|
||||
FULL ~ opt(OUTER) ^^^ FullOuter
|
||||
|
|
|
@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
|
|||
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
|
||||
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
|
||||
splitPredicates(predicates ++ condition, join)
|
||||
// All predicates can be evaluated for left semi join (those that are in the WHERE
|
||||
// clause can only come from the left table, so they can all be pushed down.)
|
||||
case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
|
||||
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
|
||||
splitPredicates(predicates ++ condition, join)
|
||||
case join @ Join(left, right, joinType, condition) =>
|
||||
logger.debug(s"Considering hash join on: $condition")
|
||||
splitPredicates(condition.toSeq, join)
|
||||
|
|
|
@ -22,3 +22,4 @@ case object Inner extends JoinType
|
|||
case object LeftOuter extends JoinType
|
||||
case object RightOuter extends JoinType
|
||||
case object FullOuter extends JoinType
|
||||
// Left semi join: the logical Join node exposes only the left side's output for this type.
case object LeftSemi extends JoinType
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.plans.logical
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.JoinType
|
||||
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
|
||||
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
|
||||
|
@ -81,7 +81,12 @@ case class Join(
|
|||
condition: Option[Expression]) extends BinaryNode {
|
||||
|
||||
def references = condition.map(_.references).getOrElse(Set.empty)
|
||||
def output = left.output ++ right.output
|
||||
def output = joinType match {
|
||||
case LeftSemi =>
|
||||
left.output
|
||||
case _ =>
|
||||
left.output ++ right.output
|
||||
}
|
||||
}
|
||||
|
||||
case class InsertIntoTable(
|
||||
|
|
|
@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
val strategies: Seq[Strategy] =
|
||||
TakeOrdered ::
|
||||
PartialAggregation ::
|
||||
LeftSemiJoin ::
|
||||
HashJoin ::
|
||||
ParquetOperations ::
|
||||
BasicOperators ::
|
||||
|
|
|
@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
|
|||
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
||||
self: SQLContext#SparkPlanner =>
|
||||
|
||||
/**
 * Planning strategy for logical joins whose join type is LeftSemi.
 *
 * Joins that the HashFilteredJoin pattern can decompose into hash keys are planned as
 * [[execution.LeftSemiJoinHash]]; any other LeftSemi join is planned as
 * [[execution.LeftSemiJoinBNL]]. Plans that are not LeftSemi joins yield no candidates.
 */
object LeftSemiJoin extends Strategy with PredicateHelper {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // At least some predicates can be evaluated by matching hash keys.
    case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
      val hashSemiJoin = execution.LeftSemiJoinHash(
        leftKeys, rightKeys, planLater(left), planLater(right))
      // Whatever predicate is left over is evaluated on top of the semi join.
      // NOTE(review): the semi join's output is only the left side, so this presumes the
      // leftover predicate never references right-side attributes -- confirm against
      // HashFilteredJoin.
      val planned: SparkPlan = condition match {
        case Some(residual) => Filter(residual, hashSemiJoin)
        case None => hashSemiJoin
      }
      planned :: Nil

    // No predicate can be evaluated by matching hash keys: use the nested loop variant.
    case logical.Join(left, right, LeftSemi, condition) =>
      execution.LeftSemiJoinBNL(
        planLater(left), planLater(right), condition)(sparkContext) :: Nil

    case _ => Nil
  }
}
|
||||
|
||||
object HashJoin extends Strategy with PredicateHelper {
|
||||
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
// Find inner joins where at least some predicates can be evaluated by matching hash keys
|
||||
|
|
|
@ -140,6 +140,137 @@ case class HashJoin(
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * :: DeveloperApi ::
 * Builds the right table's join keys into a HashSet, then iterates through the left table
 * and emits each left row whose join keys are present in that HashSet.
 *
 * A key containing a null never matches: such keys are neither inserted from the right side
 * nor looked up from the left side. The output consists of the left side's attributes only.
 */
@DeveloperApi
case class LeftSemiJoinHash(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    left: SparkPlan,
    right: SparkPlan) extends BinaryNode {

  override def outputPartitioning: Partitioning = left.outputPartitioning

  override def requiredChildDistribution =
    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

  // The right side is always the build side; the left side is streamed.
  val (buildPlan, streamedPlan) = (right, left)
  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)

  def output = left.output

  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
  @transient lazy val streamSideKeyGenerator =
    () => new MutableProjection(streamedKeys, streamedPlan.output)

  def execute() = {

    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
      val hashTable = new java.util.HashSet[Row]()

      // Create a hash set of the build-side keys. HashSet.add is a no-op for a key that is
      // already present, so no separate contains() check is needed.
      while (buildIter.hasNext) {
        val rowKey = buildSideKeyGenerator(buildIter.next())
        if (!rowKey.anyNull) {
          hashTable.add(rowKey)
        }
      }

      new Iterator[Row] {
        private[this] var currentStreamedRow: Row = _
        // True while `currentStreamedRow` holds a matched row that next() has not yet returned.
        private[this] var currentHashMatched: Boolean = false

        private[this] val joinKeys = streamSideKeyGenerator()

        // Idempotent: a pending matched row is reported without touching the streamed
        // iterator again, so repeated hasNext calls cannot skip rows.
        override final def hasNext: Boolean =
          currentHashMatched || fetchNext()

        override final def next(): Row = {
          // Also makes next() correct when the caller did not call hasNext first.
          if (!hasNext) {
            throw new java.util.NoSuchElementException("next on empty iterator")
          }
          currentHashMatched = false
          currentStreamedRow
        }

        /**
         * Searches the streamed iterator for the next row that has a match in the hash set.
         *
         * @return true if the search is successful, and false if the streamed iterator runs
         *         out of tuples.
         */
        private final def fetchNext(): Boolean = {
          while (!currentHashMatched && streamIter.hasNext) {
            currentStreamedRow = streamIter.next()
            if (!joinKeys(currentStreamedRow).anyNull) {
              currentHashMatched = hashTable.contains(joinKeys.currentValue)
            }
          }
          currentHashMatched
        }
      }
    }
  }
}
|
||||
|
||||
/**
 * :: DeveloperApi ::
 * Computes a left semi join as a broadcast nested loop join, for use when there are no join
 * keys for a hash join.
 *
 * The broadcast (right) side is executed, each of its rows is copied, and the collected rows
 * are broadcast. A streamed (left) row is kept when the join condition holds for it paired
 * with at least one broadcast row; with no condition, every pairing counts as a match.
 */
@DeveloperApi
case class LeftSemiJoinBNL(
    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
    (@transient sc: SparkContext)
  extends BinaryNode {
  // TODO: Override requiredChildDistribution.

  override def outputPartitioning: Partitioning = streamed.outputPartitioning

  override def otherCopyArgs = sc :: Nil

  def output = left.output

  /** The Streamed Relation */
  def left = streamed
  /** The Broadcast relation */
  def right = broadcast

  // The join condition bound against the concatenated (left ++ right) output, or a constant
  // true predicate when no condition was given.
  @transient lazy val boundCondition = {
    val predicate: Expression = condition match {
      case Some(c) => BindReferences.bindReference(c, left.output ++ right.output)
      case None => Literal(true)
    }
    InterpretedPredicate(predicate)
  }


  def execute() = {
    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

    streamed.execute().mapPartitions { streamedIter =>
      // One JoinedRow per partition, re-pointed at each candidate pair.
      val joinedRow = new JoinedRow

      streamedIter.filter { streamedRow =>
        // exists stops at the first broadcast row that satisfies the condition.
        broadcastedRelation.value.exists { broadcastedRow =>
          boundCondition(joinedRow(streamedRow, broadcastedRow))
        }
      }
    }
  }
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
*/
|
||||
|
|
|
@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
|
|||
fail(
|
||||
s"""
|
||||
|Exception thrown while executing query:
|
||||
|${rdd.logicalPlan}
|
||||
|${rdd.queryExecution}
|
||||
|== Exception ==
|
||||
|$e
|
||||
""".stripMargin)
|
||||
|
|
|
@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
|
|||
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
|
||||
}
|
||||
|
||||
// Left semi join whose ON clause is an inequality (x.a >= y.a + 2) rather than an equality
// between the two sides, i.e. a left semi join that can't be done with a hash join.
// Only columns of the left relation (x) are expected in the answer.
test("left semi greater than predicate") {
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
|
||||
Seq((3,1), (3,2))
|
||||
)
|
||||
}
|
||||
|
||||
test("index into array of arrays") {
|
||||
checkAnswer(
|
||||
sql(
|
||||
|
|
|
@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
|
|||
DataSinks,
|
||||
Scripts,
|
||||
PartialAggregation,
|
||||
LeftSemiJoin,
|
||||
HashJoin,
|
||||
BasicOperators,
|
||||
CartesianProduct,
|
||||
|
|
|
@ -685,6 +685,7 @@ private[hive] object HiveQl {
|
|||
case "TOK_RIGHTOUTERJOIN" => RightOuter
|
||||
case "TOK_LEFTOUTERJOIN" => LeftOuter
|
||||
case "TOK_FULLOUTERJOIN" => FullOuter
|
||||
case "TOK_LEFTSEMIJOIN" => LeftSemi
|
||||
}
|
||||
assert(other.size <= 1, "Unhandled join clauses.")
|
||||
Join(nodeToRelation(relation1),
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
Hank 2
|
||||
Hank 2
|
||||
Joe 2
|
||||
Joe 2
|
|
@ -0,0 +1 @@
|
|||
0
|
|
@ -0,0 +1,4 @@
|
|||
Hank 2
|
||||
Hank 2
|
||||
Joe 2
|
||||
Joe 2
|
|
@ -0,0 +1,2 @@
|
|||
2 Tie
|
||||
2 Tie
|
|
@ -0,0 +1 @@
|
|||
0
|
|
@ -0,0 +1,2 @@
|
|||
1
|
||||
1
|
|
@ -0,0 +1,2 @@
|
|||
1
|
||||
1
|
|
@ -0,0 +1,20 @@
|
|||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
|
@ -0,0 +1 @@
|
|||
0
|
|
@ -0,0 +1 @@
|
|||
0
|
|
@ -0,0 +1 @@
|
|||
0
|
|
@ -0,0 +1,2 @@
|
|||
1
|
||||
1
|
|
@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
|||
"lateral_view",
|
||||
"lateral_view_cp",
|
||||
"lateral_view_ppd",
|
||||
"leftsemijoin",
|
||||
"leftsemijoin_mr",
|
||||
"lineage1",
|
||||
"literal_double",
|
||||
"literal_ints",
|
||||
|
|
Loading…
Reference in a new issue