# Patch for Spark SQL test sources: JoinSuite.scala and TestData.scala.
#
# What the patch does (as visible in the hunks below):
#   * JoinSuite imports: drops JoinType, LeftSemi, execution._ and TestSQLContext, and
#     reorders the remaining imports.
#   * Hoists `assertJoin` out of the "join operator selection" test to class level so that
#     more than one test can call it.
#   * "join operator selection": calls clearCache() first, normalises keyword case and
#     spacing in most queries, and drops the three ShuffledHashJoin cases that were
#     repeated at the end of the old list.
#   * Adds "broadcasted hash join operator selection": caches testData and expects
#     BroadcastHashJoin for three equi-join queries.
#   * Replaces five blank lines in JoinSuite (presumably trailing-whitespace-only lines;
#     the removed content is not visible) with empty lines.
#   * TestData: appends `.toSchemaRDD` to the upperCaseData and lowerCaseData definitions.
#
# Review changes relative to the patch as originally recorded:
#   * `assertJoin` now declares `Unit` instead of `Any`; it is only run for its assertions.
#   * The broadcast test un-caches testData in a `finally` block, so a failing assertion can
#     no longer leave the shared table cached for the tests that follow.
#   * Because of the five added lines, the second JoinSuite hunk is now `+37,70` and the
#     new-file offsets of the later JoinSuite hunks are shifted by five.
#
# NOTE(review): the `index` line still records the blob ids of the original patch; the
# post-image id no longer matches the edited content. Indentation was reconstructed from a
# whitespace-collapsed copy of the patch - confirm against the target files before applying.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 07f4d2946c..8b4cf5bac0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -19,17 +19,13 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 
 class JoinSuite extends QueryTest with BeforeAndAfterEach {
-
   // Ensures tables are loaded.
   TestData
 
@@ -41,54 +37,70 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     assert(planned.size === 1)
   }
 
-  test("join operator selection") {
-    def assertJoin(sqlString: String, c: Class[_]): Any = {
-      val rdd = sql(sqlString)
-      val physical = rdd.queryExecution.sparkPlan
-      val operators = physical.collect {
-        case j: ShuffledHashJoin => j
-        case j: HashOuterJoin => j
-        case j: LeftSemiJoinHash => j
-        case j: BroadcastHashJoin => j
-        case j: LeftSemiJoinBNL => j
-        case j: CartesianProduct => j
-        case j: BroadcastNestedLoopJoin => j
-      }
-
-      assert(operators.size === 1)
-      if (operators(0).getClass() != c) {
-        fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
-      }
+  def assertJoin(sqlString: String, c: Class[_]): Unit = {
+    val rdd = sql(sqlString)
+    val physical = rdd.queryExecution.sparkPlan
+    val operators = physical.collect {
+      case j: ShuffledHashJoin => j
+      case j: HashOuterJoin => j
+      case j: LeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
+      case j: LeftSemiJoinBNL => j
+      case j: CartesianProduct => j
+      case j: BroadcastNestedLoopJoin => j
     }
 
-    val cases1 = Seq(
-      ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
-      ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
-      ("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
-      ("SELECT * FROM testData right join testData2 ON key = a where key=2",
+    assert(operators.size === 1)
+    if (operators(0).getClass() != c) {
+      fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
+    }
+  }
+
+  test("join operator selection") {
+    clearCache()
+
+    Seq(
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+      ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
+      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[HashOuterJoin]),
-      ("SELECT * FROM testData right join testData2 ON key = a and key=2",
+      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
         classOf[HashOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin])
-      // TODO add BroadcastNestedLoopJoin
-    )
-    cases1.foreach { c => assertJoin(c._1, c._2) }
+      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
+      // TODO add BroadcastNestedLoopJoin
+    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+  }
+
+  test("broadcasted hash join operator selection") {
+    clearCache()
+    sql("CACHE TABLE testData")
+
+    val cases = Seq(
+      ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
+    )
+
+    try {
+      cases.foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      // Uncache even when an assertion fails so later tests do not see a cached testData.
+      sql("UNCACHE TABLE testData")
+    }
   }
 
   test("multiple-key equi-join is hash-join") {
@@ -171,7 +183,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
       (1, "A", null, null) ::
@@ -180,7 +192,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
       (1, "A", null, null) ::
@@ -189,7 +201,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
       (1, "A", 1, "a") ::
@@ -300,7 +312,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
-    
+
     checkAnswer(
       left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
       (1, "A", null, null) ::
@@ -310,7 +322,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
-    
+
     checkAnswer(
       left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
       (1, "A", null, null) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 6c38575b13..c4dd3e860f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -80,7 +80,7 @@ object TestData {
       UpperCaseData(3, "C") ::
       UpperCaseData(4, "D") ::
       UpperCaseData(5, "E") ::
-      UpperCaseData(6, "F") :: Nil)
+      UpperCaseData(6, "F") :: Nil).toSchemaRDD
   upperCaseData.registerTempTable("upperCaseData")
 
   case class LowerCaseData(n: Int, l: String)
@@ -89,7 +89,7 @@ object TestData {
       LowerCaseData(1, "a") ::
       LowerCaseData(2, "b") ::
       LowerCaseData(3, "c") ::
-      LowerCaseData(4, "d") :: Nil)
+      LowerCaseData(4, "d") :: Nil).toSchemaRDD
   lowerCaseData.registerTempTable("lowerCaseData")
 
   case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])