From 65805ab6eaa6b0ed1dc7f6cbd3c3368eb750b375 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 24 Feb 2016 12:03:45 -0800
Subject: [PATCH] Revert "Revert "[SPARK-13383][SQL] Keep broadcast hint after column pruning""

This reverts commit 382b27babf7771b724f7abff78195a858631d138.
---
 .../plans/logical/basicOperators.scala        |  4 +++
 ...uite.scala => JoinOptimizationSuite.scala} | 35 ++++++++++++++++---
 .../spark/sql/execution/SparkStrategies.scala | 12 ++++---
 3 files changed, 42 insertions(+), 9 deletions(-)
 rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{JoinOrderSuite.scala => JoinOptimizationSuite.scala} (77%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index af43cb3786..5d2a65b716 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -332,6 +332,10 @@ case class Join(
  */
 case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+
+  // We manually set statistics of BroadcastHint to smallest value to make sure
+  // the plan wrapped by BroadcastHint will be considered to broadcast later.
+  override def statistics: Statistics = Statistics(sizeInBytes = 1)
 }
 
 case class InsertIntoTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
similarity index 77%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index a5b487bcc8..d482519827 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
 
-class JoinOrderSuite extends PlanTest {
+class JoinOptimizationSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", Once,
         EliminateSubqueryAliases) ::
-      Batch("Filter Pushdown", Once,
+      Batch("Filter Pushdown", FixedPoint(100),
         CombineFilters,
         PushPredicateThroughProject,
         BooleanSimplification,
@@ -92,4 +92,31 @@ class JoinOrderSuite extends PlanTest {
 
     comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
   }
+
+  test("broadcasthint sets relation statistics to smallest value") {
+    val input = LocalRelation('key.int, 'value.string)
+
+    val query =
+      Project(Seq($"x.key", $"y.key"),
+        Join(
+          SubqueryAlias("x", input),
+          BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze
+
+    val optimized = Optimize.execute(query)
+
+    val expected =
+      Project(Seq($"x.key", $"y.key"),
+        Join(
+          Project(Seq($"x.key"), SubqueryAlias("x", input)),
+          BroadcastHint(
+            Project(Seq($"y.key"), SubqueryAlias("y", input))),
+          Inner, None)).analyze
+
+    comparePlans(optimized, expected)
+
+    val broadcastChildren = optimized.collect {
+      case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r
+    }
+    assert(broadcastChildren.size == 1)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7347156398..247eb054a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -81,11 +81,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * Matches a plan whose output should be small enough to be used in broadcast join.
    */
   object CanBroadcast {
-    def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
-      case BroadcastHint(p) => Some(p)
-      case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
-        p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p)
-      case _ => None
+    def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
+      if (sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+          plan.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold) {
+        Some(plan)
+      } else {
+        None
+      }
     }
   }
 
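
For context, a minimal sketch of how the broadcast hint handled by this patch is normally supplied from the DataFrame API. It is not part of the patch: the application name, data, and column names below are made up for illustration, and it assumes a Spark build of the same era (SQLContext and org.apache.spark.sql.functions.broadcast).

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.broadcast

object BroadcastHintExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "broadcast-hint-example")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Toy data; names and sizes are arbitrary.
    val large = sc.parallelize(1 to 100000).map(i => (i, s"v$i")).toDF("key", "value")
    val small = sc.parallelize(1 to 100).map(i => (i, s"w$i")).toDF("key", "word")

    // broadcast() wraps the small side's logical plan in a BroadcastHint node.
    // With this patch applied, BroadcastHint reports sizeInBytes = 1, so the
    // size-based CanBroadcast check still selects the hinted side after column
    // pruning has rewritten the plan around the hint (this is what the
    // JoinOptimizationSuite test above asserts). Note that CanBroadcast also
    // requires spark.sql.autoBroadcastJoinThreshold to be positive.
    val joined = large.join(broadcast(small), "key").select("key", "word")
    joined.explain(true)

    sc.stop()
  }
}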