[SPARK-1495][SQL]add support for left semi join
Just submitting another solution for #395.

Author: Daoyuan <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
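For context, a LEFT SEMI JOIN returns each row of the left table that has at least one match on the right, emits it at most once, and exposes only the left table's columns. A minimal collection-based sketch of these semantics (toy data, not from the patch):

```scala
// Semi join on plain Scala collections: keep left rows whose key has a match.
object SemiJoinSemantics {
  def main(args: Array[String]): Unit = {
    val people = Seq(("Joe", 2), ("Hank", 2), ("Ann", 5)) // (name, deptId)
    val depts  = Set(2, 3)                                // existing dept ids

    // Roughly: SELECT * FROM people LEFT SEMI JOIN depts ON people.deptId = depts.id
    val semi = people.filter { case (_, deptId) => depts(deptId) }
    println(semi) // List((Joe,2), (Hank,2)) -- right-side columns never appear
  }
}
```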
Parent: 35630c86ff
Commit: 0cf6002801
@@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val OUTER = Keyword("OUTER")
   protected val RIGHT = Keyword("RIGHT")
   protected val SELECT = Keyword("SELECT")
+  protected val SEMI = Keyword("SEMI")
   protected val STRING = Keyword("STRING")
   protected val SUM = Keyword("SUM")
   protected val TRUE = Keyword("TRUE")
@@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {

   protected lazy val joinType: Parser[JoinType] =
     INNER ^^^ Inner |
+    LEFT ~ SEMI ^^^ LeftSemi |
     LEFT ~ opt(OUTER) ^^^ LeftOuter |
     RIGHT ~ opt(OUTER) ^^^ RightOuter |
     FULL ~ opt(OUTER) ^^^ FullOuter
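The `^^^` combinator discards the matched tokens and returns a constant, so `LEFT ~ SEMI` maps straight to the `LeftSemi` case object. Ordering also matters: the `LEFT ~ SEMI` alternative must be tried before `LEFT ~ opt(OUTER)`, or "LEFT SEMI" would be consumed as a left outer join with a stray SEMI token left over. A toy sketch, assuming the standard scala-parser-combinators `StandardTokenParsers` that `SqlParser` itself extends:

```scala
import scala.util.parsing.combinator.syntactical.StandardTokenParsers

// Minimal grammar demonstrating `^^^`: parse keywords, substitute a constant.
object JoinTypeParserDemo extends StandardTokenParsers {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftSemi extends JoinType
  case object LeftOuter extends JoinType

  lexical.reserved ++= Seq("INNER", "LEFT", "OUTER", "SEMI")

  lazy val joinType: Parser[JoinType] =
    "INNER" ^^^ Inner |
    "LEFT" ~ "SEMI" ^^^ LeftSemi |      // must be tried before LEFT [OUTER]
    "LEFT" ~ opt("OUTER") ^^^ LeftOuter

  def parse(s: String): JoinType =
    phrase(joinType)(new lexical.Scanner(s)).get

  def main(args: Array[String]): Unit = {
    println(parse("LEFT SEMI"))  // LeftSemi
    println(parse("LEFT OUTER")) // LeftOuter
    println(parse("LEFT"))       // LeftOuter
  }
}
```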
@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
     case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
       logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
       splitPredicates(predicates ++ condition, join)
+    // All predicates can be evaluated for left semi join (those that are in the WHERE
+    // clause can only come from the left table, so they can all be pushed down).
+    case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
+      logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
+      splitPredicates(predicates ++ condition, join)
     case join @ Join(left, right, joinType, condition) =>
       logger.debug(s"Considering hash join on: $condition")
       splitPredicates(condition.toSeq, join)
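The reasoning in the new comment can be checked on plain collections: because a left semi join outputs only left-side rows, filtering after the join equals filtering before it, so WHERE predicates are always safe to push down. A small self-contained sketch with a hypothetical Order schema:

```scala
// Filter-after-semi-join == filter-before-semi-join, since the join's output
// is always a subset of the left input's rows.
object SemiJoinPushdown {
  case class Order(id: Int, customerId: Int, total: Double) // hypothetical schema

  def main(args: Array[String]): Unit = {
    val orders      = Seq(Order(1, 10, 50.0), Order(2, 11, 150.0), Order(3, 99, 200.0))
    val customerIds = Set(10, 11)

    val semiJoin  = (os: Seq[Order]) => os.filter(o => customerIds(o.customerId))
    val wherePred = (o: Order) => o.total > 100

    val filterAfter  = semiJoin(orders).filter(wherePred)  // WHERE on the join output
    val filterBefore = semiJoin(orders.filter(wherePred))  // WHERE pushed below the join

    assert(filterAfter == filterBefore)
    println(filterAfter) // List(Order(2,11,150.0))
  }
}
```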
@@ -22,3 +22,4 @@ case object Inner extends JoinType
 case object LeftOuter extends JoinType
 case object RightOuter extends JoinType
 case object FullOuter extends JoinType
+case object LeftSemi extends JoinType
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
 import org.apache.spark.sql.catalyst.types._

 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,7 +81,12 @@ case class Join(
     condition: Option[Expression]) extends BinaryNode {

   def references = condition.map(_.references).getOrElse(Set.empty)
-  def output = left.output ++ right.output
+  def output = joinType match {
+    case LeftSemi =>
+      left.output
+    case _ =>
+      left.output ++ right.output
+  }
 }

 case class InsertIntoTable(
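The consequence of this change, modelled on toy types: only a LeftSemi join narrows the output schema to the left child's attributes, so `SELECT *` over a semi join resolves to left-side columns only. A sketch with a hypothetical `Attr` type, not Spark's:

```scala
// Toy model of the new Join.output rule.
object JoinOutputRule {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftSemi extends JoinType

  case class Attr(name: String)

  def joinOutput(t: JoinType, left: Seq[Attr], right: Seq[Attr]): Seq[Attr] = t match {
    case LeftSemi => left          // semi join exposes only left columns
    case _        => left ++ right // all other joins concatenate both sides
  }

  def main(args: Array[String]): Unit = {
    println(joinOutput(LeftSemi, Seq(Attr("a"), Attr("b")), Seq(Attr("n")))) // List(Attr(a), Attr(b))
    println(joinOutput(Inner,    Seq(Attr("a"), Attr("b")), Seq(Attr("n")))) // List(Attr(a), Attr(b), Attr(n))
  }
}
```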
@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TakeOrdered ::
       PartialAggregation ::
+      LeftSemiJoin ::
       HashJoin ::
       ParquetOperations ::
       BasicOperators ::
@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>

+  object LeftSemiJoin extends Strategy with PredicateHelper {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      // Find left semi joins where at least some predicates can be evaluated by matching hash
+      // keys using the HashFilteredJoin pattern.
+      case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+        val semiJoin = execution.LeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+      // no predicate can be evaluated by matching hash keys
+      case logical.Join(left, right, LeftSemi, condition) =>
+        execution.LeftSemiJoinBNL(
+          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+      case _ => Nil
+    }
+  }
+
   object HashJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Find inner joins where at least some predicates can be evaluated by matching hash keys
@@ -140,6 +140,137 @@ case class HashJoin(
   }
 }

+/**
+ * :: DeveloperApi ::
+ * Build the right table's join keys into a HashSet, and iteratively go through the left
+ * table to find if the join keys are in the hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  val (buildPlan, streamedPlan) = (right, left)
+  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+
+  def output = left.output
+
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
+
+  def execute() = {
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashTable = new java.util.HashSet[Row]()
+      var currentRow: Row = null
+
+      // Create a hash set of the build-side keys.
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if (!rowKey.anyNull) {
+          val keyExists = hashTable.contains(rowKey)
+          if (!keyExists) {
+            hashTable.add(rowKey)
+          }
+        }
+      }
+
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatched: Boolean = false
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          streamIter.hasNext && fetchNext()
+
+        override final def next() = {
+          currentStreamedRow
+        }
+
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in the
+         * hash table.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs out
+         *         of tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatched = false
+          while (!currentHashMatched && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatched = hashTable.contains(joinKeys.currentValue)
+            }
+          }
+          currentHashMatched
+        }
+      }
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join
+ * keys available for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+    (@transient sc: SparkContext)
+  extends BinaryNode {
+  // TODO: Override requiredChildDistribution.
+
+  override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+  override def otherCopyArgs = sc :: Nil
+
+  def output = left.output
+
+  /** The Streamed Relation */
+  def left = streamed
+  /** The Broadcast relation */
+  def right = broadcast
+
+  @transient lazy val boundCondition =
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
+
+  def execute() = {
+    val broadcastedRelation =
+      sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+    streamed.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow
+
+      streamedIter.filter(streamedRow => {
+        var i = 0
+        var matched = false
+
+        while (i < broadcastedRelation.value.size && !matched) {
+          val broadcastedRow = broadcastedRelation.value(i)
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+            matched = true
+          }
+          i += 1
+        }
+        matched
+      })
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  */
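Condensed to plain collections, the per-partition algorithm of LeftSemiJoinHash is: hash the build (right) side's keys, then stream the left side and keep each row whose key is present in the set. Rows with a null key never match, mirroring the `anyNull` checks above. A self-contained sketch with rows modelled as (Option[key], payload) tuples:

```scala
// Collection-based model of the hash semi join performed per partition pair.
object LeftSemiJoinHashSketch {
  type Row = (Option[Int], String) // (join key, payload)

  def leftSemiJoin(left: Iterator[Row], right: Iterator[Row]): Iterator[Row] = {
    val keys = scala.collection.mutable.HashSet.empty[Int]
    right.foreach { case (k, _) => k.foreach(keys += _) }  // build side; skip null keys
    left.filter { case (k, _) => k.exists(keys.contains) } // each left row emitted at most once
  }

  def main(args: Array[String]): Unit = {
    val left  = Iterator(Some(1) -> "a", Some(2) -> "b", (None: Option[Int]) -> "c")
    val right = Iterator(Some(1) -> "x", Some(1) -> "y")
    println(leftSemiJoin(left, right).toList) // List((Some(1),a))
  }
}
```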
@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
       fail(
         s"""
           |Exception thrown while executing query:
-          |${rdd.logicalPlan}
+          |${rdd.queryExecution}
           |== Exception ==
           |$e
         """.stripMargin)
@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
       arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
   }

+  test("left semi greater than predicate") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+      Seq((3,1), (3,2))
+    )
+  }
+
   test("index into array of arrays") {
     checkAnswer(
       sql(
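The expected answer can be reproduced on plain collections, assuming testData2 holds the six rows (1,1), (1,2), (2,1), (2,2), (3,1), (3,2) used elsewhere in Spark's SQL test fixtures: a row of x survives the semi join iff some row of y satisfies x.a >= y.a + 2, which holds only for x.a = 3 (take y.a = 1), and each surviving row is emitted exactly once.

```scala
// Sanity check of the golden answer under the assumed testData2 contents.
object LeftSemiGreaterThanCheck {
  def main(args: Array[String]): Unit = {
    val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    val result = testData2.filter { case (xa, _) =>
      testData2.exists { case (ya, _) => xa >= ya + 2 }
    }
    println(result) // List((3,1), (3,2))
  }
}
```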
@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
+      LeftSemiJoin,
       HashJoin,
       BasicOperators,
       CartesianProduct,
@@ -685,6 +685,7 @@ private[hive] object HiveQl {
       case "TOK_RIGHTOUTERJOIN" => RightOuter
       case "TOK_LEFTOUTERJOIN" => LeftOuter
       case "TOK_FULLOUTERJOIN" => FullOuter
+      case "TOK_LEFTSEMIJOIN" => LeftSemi
     }
     assert(other.size <= 1, "Unhandled join clauses.")
     Join(nodeToRelation(relation1),
@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2
@@ -0,0 +1 @@
+0
@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2
@@ -0,0 +1,2 @@
+2 Tie
+2 Tie
@@ -0,0 +1 @@
+0
@@ -0,0 +1,2 @@
+1
+1
@@ -0,0 +1,2 @@
+1
+1
@@ -0,0 +1,20 @@
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
@@ -0,0 +1 @@
+0
@@ -0,0 +1 @@
+0
@@ -0,0 +1 @@
+0
@@ -0,0 +1,2 @@
+1
+1
@@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "lateral_view",
     "lateral_view_cp",
     "lateral_view_ppd",
+    "leftsemijoin",
+    "leftsemijoin_mr",
     "lineage1",
     "literal_double",
     "literal_ints",