From b6a873d6d4682796f55dbafadd0b5cad881f96ea Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 21 Feb 2016 12:32:31 -0800 Subject: [PATCH] [SPARK-13136][SQL] Create a dedicated Broadcast exchange operator Quite a few Spark SQL join operators broadcast one side of the join to all nodes. There are a few problems with this: - This conflates broadcasting (a data exchange) with joining. Data exchanges should be managed by a different operator. - All these operators implement their own (duplicate) broadcasting logic. - Re-use of indices is quite hard. This PR defines both a `BroadcastDistribution` and a `BroadcastPartitioning`, both of which contain a `BroadcastMode`. The `BroadcastMode` defines the way in which we transform the `Array` of `InternalRow`s into an index. We currently support the following `BroadcastMode`s: - IdentityBroadcastMode: This broadcasts the rows in their original form. - HashSetBroadcastMode: This applies a projection to the input rows, deduplicates these rows and broadcasts the resulting `Set`. - HashedRelationBroadcastMode: This transforms the input rows into a `HashedRelation`, and broadcasts this index. To satisfy this distribution we implement a `BroadcastExchange` operator that performs the broadcast, and we have `EnsureRequirements` plan this operator. The old Exchange operator has been renamed to ShuffleExchange in order to clearly distinguish between shuffled and broadcast exchanges. Finally, the classes in Exchange.scala have been moved to a dedicated package. cc rxin davies Author: Herman van Hovell Closes #11083 from hvanhovell/SPARK-13136. --- .../plans/physical/broadcastMode.scala | 35 +++ .../plans/physical/partitioning.scala | 30 +- .../org/apache/spark/sql/SQLContext.scala | 12 +- .../spark/sql/execution/SparkPlan.scala | 34 ++- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../sql/execution/WholeStageCodegen.scala | 5 + .../exchange/BroadcastExchange.scala | 89 ++++++ .../exchange/EnsureRequirements.scala | 261 ++++++++++++++++++ .../{ => exchange}/ExchangeCoordinator.scala | 46 +-- .../ShuffleExchange.scala} | 255 +---------------- .../execution/joins/BroadcastHashJoin.scala | 75 +---- .../joins/BroadcastLeftSemiJoinHash.scala | 25 +- .../joins/BroadcastNestedLoopJoin.scala | 25 +- .../sql/execution/joins/HashSemiJoin.scala | 51 ++-- .../sql/execution/joins/HashedRelation.scala | 22 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 23 +- .../apache/spark/sql/execution/limit.scala | 7 +- .../apache/spark/sql/CachedTableSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- .../execution/ExchangeCoordinatorSuite.scala | 21 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 21 +- .../execution/joins/BroadcastJoinSuite.scala | 5 +- .../sql/execution/joins/InnerJoinSuite.scala | 11 +- .../sql/execution/joins/OuterJoinSuite.scala | 3 +- .../sql/execution/joins/SemiJoinSuite.scala | 3 +- .../spark/sql/sources/BucketedReadSuite.scala | 13 +- 27 files changed, 658 insertions(+), 433 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/{ => exchange}/ExchangeCoordinator.scala (85%) rename
sql/core/src/main/scala/org/apache/spark/sql/execution/{Exchange.scala => exchange/ShuffleExchange.scala} (50%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala new file mode 100644 index 0000000000..c646dcfa11 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.physical + +import org.apache.spark.sql.catalyst.InternalRow + +/** + * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are + * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). + */ +trait BroadcastMode { + def transform(rows: Array[InternalRow]): Any +} + +/** + * IdentityBroadcastMode requires that rows are broadcasted in their original form. + */ +case object IdentityBroadcastMode extends BroadcastMode { + override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d6e10c412c..45e2841ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, IntegerType} @@ -75,6 +76,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { def clustering: Set[Expression] = ordering.map(_.child).toSet } +/** + * Represents data where tuples are broadcasted to every node. It is quite common that the + * entire set of tuples is transformed into different data structure. + */ +case class BroadcastDistribution(mode: BroadcastMode) extends Distribution + /** * Describes how an operator's output is split across partitions. 
The `compatibleWith`, * `guarantees`, and `satisfies` methods describe relationships between child partitionings, @@ -213,7 +220,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning { case object SinglePartition extends Partitioning { val numPartitions = 1 - override def satisfies(required: Distribution): Boolean = true + override def satisfies(required: Distribution): Boolean = required match { + case _: BroadcastDistribution => false + case _ => true + } override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1 @@ -351,3 +361,21 @@ case class PartitioningCollection(partitionings: Seq[Partitioning]) partitionings.map(_.toString).mkString("(", " or ", ")") } } + +/** + * Represents a partitioning where rows are collected, transformed and broadcasted to each + * node in the cluster. + */ +case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { + override val numPartitions: Int = 1 + + override def satisfies(required: Distribution): Boolean = required match { + case BroadcastDistribution(m) if m == mode => true + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case BroadcastPartitioning(m) if m == mode => true + case _ => false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 932df36b85..a2f386850c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ @@ -59,7 +60,6 @@ import org.apache.spark.util.Utils * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation * @groupname Ungrouped Support functions for language integrated queries - * * @since 1.0.0 */ class SQLContext private[sql]( @@ -313,10 +313,10 @@ class SQLContext private[sql]( } /** - * Returns true if the [[Queryable]] is currently cached in-memory. - * @group cachemgmt - * @since 1.3.0 - */ + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ private[sql] def isCached(qName: Queryable): Boolean = { cacheManager.lookupCachedData(qName).nonEmpty } @@ -364,6 +364,7 @@ class SQLContext private[sql]( /** * Converts $"col name" into an [[Column]]. + * * @since 1.3.0 */ // This must live here to preserve binary compatibility with Spark < 1.5. @@ -728,7 +729,6 @@ class SQLContext private[sql]( * cached/persisted before, it's also unpersisted. * * @param tableName the name of the table to be unregistered. 
- * * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 477a9460d7..3be4cce045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -24,6 +24,7 @@ import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.Logging +import org.apache.spark.broadcast import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} @@ -108,15 +109,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** - * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute - * after adding query plan information to created RDDs for visualization. - * Concrete implementations of SparkPlan should override doExecute instead. + * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after + * preparations. Concrete implementations of SparkPlan should override doExecute. */ - final def execute(): RDD[InternalRow] = { + final def execute(): RDD[InternalRow] = executeQuery { + doExecute() + } + + /** + * Returns the result of this query as a broadcast variable by delegating to doBroadcast after + * preparations. Concrete implementations of SparkPlan should override doBroadcast. + */ + final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { + doExecuteBroadcast() + } + + /** + * Execute a query after preparing the query and adding query plan information to created RDDs + * for visualization. + */ + private final def executeQuery[T](query: => T): T = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() waitForSubqueries() - doExecute() + query } } @@ -192,6 +208,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ */ protected def doExecute(): RDD[InternalRow] + /** + * Overridden by concrete implementations of SparkPlan. + * Produces the result of the query as a broadcast variable. + */ + protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast") + } + /** * Runs this query returning the result as an array. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 382654afac..7347156398 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{execution, Strategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} @@ -328,7 +330,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.Coalesce(numPartitions, planLater(child)) :: Nil } @@ -367,7 +369,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r @ logical.Range(start, end, step, numSlices, output) => execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => - execution.Exchange(HashPartitioning( + exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case e @ python.EvaluatePython(udf, child, _) => python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 990eeb22b6..d79b547137 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer +import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -172,6 +173,10 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { child.execute() } + override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + child.doExecuteBroadcast() + } + override def supportCodegen: Boolean = false override def upstream(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala new file mode 100644 index 0000000000..40cad4b1a7 --- /dev/null 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ + +import org.apache.spark.broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryNode} +import org.apache.spark.util.ThreadUtils + +/** + * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed + * SparkPlan. + */ +case class BroadcastExchange( + mode: BroadcastMode, + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + + @transient + private val timeout: Duration = { + val timeoutValue = sqlContext.conf.broadcastTimeout + if (timeoutValue < 0) { + Duration.Inf + } else { + timeoutValue.seconds + } + } + + @transient + private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { + // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + Future { + // This will run in another thread. Set the execution id so that we can connect these jobs + // with the correct execution. + SQLExecution.withExecutionId(sparkContext, executionId) { + // Note that we use .executeCollect() because we don't want to convert data to Scala types + val input: Array[InternalRow] = child.executeCollect() + + // Construct and broadcast the relation. + sparkContext.broadcast(mode.transform(input)) + } + }(BroadcastExchange.executionContext) + } + + override protected def doPrepare(): Unit = { + // Materialize the future. 
+ relationFuture + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException( + "BroadcastExchange does not support the execute() code path.") + } + + override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { + val result = Await.result(relationFuture, timeout) + result.asInstanceOf[broadcast.Broadcast[T]] + } +} + +object BroadcastExchange { + private[execution] val executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala new file mode 100644 index 0000000000..709a424636 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ + +/** + * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] + * of input data meets the + * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for + * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the + * input partition ordering requirements are met. + */ +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { + private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions + + private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize + + private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled + + private def minNumPostShufflePartitions: Option[Int] = { + val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions + if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None + } + + /** + * Given a required distribution, returns a partitioning that satisfies that distribution. 
+ */ + private def createPartitioning( + requiredDistribution: Distribution, + numPartitions: Int): Partitioning = { + requiredDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) + case dist => sys.error(s"Do not know how to satisfy distribution $dist") + } + } + + /** + * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + */ + private def withExchangeCoordinator( + children: Seq[SparkPlan], + requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { + val supportsCoordinator = + if (children.exists(_.isInstanceOf[ShuffleExchange])) { + // Right now, ExchangeCoordinator only support HashPartitionings. + children.forall { + case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case child => + child.outputPartitioning match { + case hash: HashPartitioning => true + case collection: PartitioningCollection => + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) + case _ => false + } + } + } else { + // In this case, although we do not have Exchange operators, we may still need to + // shuffle data when we have more than one children because data generated by + // these children may not be partitioned in the same way. + // Please see the comment in withCoordinator for more details. + val supportsDistribution = + requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) + children.length > 1 && supportsDistribution + } + + val withCoordinator = + if (adaptiveExecutionEnabled && supportsCoordinator) { + val coordinator = + new ExchangeCoordinator( + children.length, + targetPostShuffleInputSize, + minNumPostShufflePartitions) + children.zip(requiredChildDistributions).map { + case (e: ShuffleExchange, _) => + // This child is an Exchange, we need to add the coordinator. + e.copy(coordinator = Some(coordinator)) + case (child, distribution) => + // If this child is not an Exchange, we need to add an Exchange for now. + // Ideally, we can try to avoid this Exchange. However, when we reach here, + // there are at least two children operators (because if there is a single child + // and we can avoid Exchange, supportsCoordinator will be false and we + // will not reach here.). Although we can make two children have the same number of + // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. + // For example, let's say we have the following plan + // Join + // / \ + // Agg Exchange + // / \ + // Exchange t2 + // / + // t1 + // In this case, because a post-shuffle partition can include multiple pre-shuffle + // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes + // after shuffle. So, even we can use the child Exchange operator of the Join to + // have a number of post-shuffle partitions that matches the number of partitions of + // Agg, we cannot say these two children are partitioned in the same way. + // Here is another case + // Join + // / \ + // Agg1 Agg2 + // / \ + // Exchange1 Exchange2 + // / \ + // t1 t2 + // In this case, two Aggs shuffle data with the same column of the join condition. + // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same + // way. 
Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 + // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle + // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its + // pre-shuffle partitions by using another partitionStartIndices [0, 4]. + // So, Agg1 and Agg2 are actually not co-partitioned. + // + // It will be great to introduce a new Partitioning to represent the post-shuffle + // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. + val targetPartitioning = + createPartitioning(distribution, defaultNumPreShufflePartitions) + assert(targetPartitioning.isInstanceOf[HashPartitioning]) + ShuffleExchange(targetPartitioning, child, Some(coordinator)) + } + } else { + // If we do not need ExchangeCoordinator, the original children are returned. + children + } + + withCoordinator + } + + private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { + val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution + val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering + var children: Seq[SparkPlan] = operator.children + assert(requiredChildDistributions.length == children.length) + assert(requiredChildOrderings.length == children.length) + + // Ensure that the operator's children satisfy their output distribution requirements: + children = children.zip(requiredChildDistributions).map { + case (child, distribution) if child.outputPartitioning.satisfies(distribution) => + child + case (child, BroadcastDistribution(mode)) => + BroadcastExchange(mode, child) + case (child, distribution) => + ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + } + + // If the operator has multiple children and specifies child output distributions (e.g. join), + // then the children's output partitionings must be compatible: + def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match { + case UnspecifiedDistribution => false + case BroadcastDistribution(_) => false + case _ => true + } + if (children.length > 1 + && requiredChildDistributions.exists(requireCompatiblePartitioning) + && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { + + // First check if the existing partitions of the children all match. This means they are + // partitioned by the same partitioning into the same number of partitions. In that case, + // don't try to make them match `defaultPartitions`, just use the existing partitioning. + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max + val useExistingPartitioning = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + + children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. + children + } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. 
+ val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + + children.zip(requiredChildDistributions).map { + case (child, distribution) => + val targetPartitioning = createPartitioning(distribution, numPartitions) + if (child.outputPartitioning.guarantees(targetPartitioning)) { + child + } else { + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. + case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) + case _ => ShuffleExchange(targetPartitioning, child) + } + } + } + } + } + + // Now, we need to add ExchangeCoordinator if necessary. + // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. + // However, with the way that we plan the query, we do not have a place where we have a + // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator + // at here for now. + // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, + // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. + children = withExchangeCoordinator(children, requiredChildDistributions) + + // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: + children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + if (requiredOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. + if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + Sort(requiredOrdering, global = false, child = child) + } else { + child + } + } else { + child + } + } + + operator.withNewChildren(children) + } + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator @ ShuffleExchange(partitioning, child, _) => + child.children match { + case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + if (childPartitioning.guarantees(partitioning)) child else operator + case _ => operator + } + case operator: SparkPlan => ensureDistributionAndOrdering(operator) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala similarity index 85% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 07015e5a5a..6f3bb0ad2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.exchange import java.util.{HashMap => JHashMap, Map => JMap} import javax.annotation.concurrent.GuardedBy @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} /** * A coordinator used to determines how we shuffle data between stages generated by Spark SQL. @@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to - * this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[Exchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered + * to this coordinator. So, when we start to do any actual work, we have a way to make sure that + * we have got expected number of [[ShuffleExchange]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -45,26 +46,27 @@ import org.apache.spark.sql.catalyst.InternalRow * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator, + * - Before the execution of a [[SparkPlan]], for an [[ShuffleExchange]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will - * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will - * immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. + * - Once we start to execute a physical plan, an [[ShuffleExchange]] registered to this + * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle + * [[ShuffledRowRDD]]. + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size - * statistics of pre-shuffle partitions, this coordinator will determine the number of + * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the the + * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can - * lookup the corresponding [[RDD]]. + * [[ShuffleExchange]]s. 
So, when an [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator + * can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and + * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * the size of a post-shuffle partition is equal or greater than the target size. * For example, we have two stages with the following pre-shuffle partition size statistics: @@ -83,11 +85,11 @@ private[sql] class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[Exchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchange]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] = - new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = + new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -95,11 +97,11 @@ private[sql] class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be - * called in the `doPrepare` method of an [[Exchange]] operator. + * Registers an [[ShuffleExchange]] operator to this coordinator. This method is only allowed to + * be called in the `doPrepare` method of an [[ShuffleExchange]] operator. */ @GuardedBy("this") - def registerExchange(exchange: Exchange): Unit = synchronized { + def registerExchange(exchange: ShuffleExchange): Unit = synchronized { exchanges += exchange } @@ -199,7 +201,7 @@ private[sql] class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. 
assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -254,7 +256,7 @@ private[sql] class ExchangeCoordinator( } } - def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala similarity index 50% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index e30adefc69..de21d7705e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.exchange import java.util.Random @@ -24,19 +24,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class Exchange( +case class ShuffleExchange( var newPartitioning: Partitioning, child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode { @@ -81,7 +80,8 @@ case class Exchange( * the returned ShuffleDependency will be the input of shuffle. 
*/ private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { - Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer) + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, newPartitioning, serializer) } /** @@ -116,9 +116,9 @@ case class Exchange( } } -object Exchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) +object ShuffleExchange { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { + ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } /** @@ -259,238 +259,3 @@ object Exchange { dependency } } - -/** - * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] - * of input data meets the - * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. - */ -private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { - private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions - - private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize - - private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled - - private def minNumPostShufflePartitions: Option[Int] = { - val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions - if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None - } - - /** - * Given a required distribution, returns a partitioning that satisfies that distribution. - */ - private def createPartitioning( - requiredDistribution: Distribution, - numPartitions: Int): Partitioning = { - requiredDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) - case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) - case dist => sys.error(s"Do not know how to satisfy distribution $dist") - } - } - - /** - * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]]. - */ - private def withExchangeCoordinator( - children: Seq[SparkPlan], - requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { - val supportsCoordinator = - if (children.exists(_.isInstanceOf[Exchange])) { - // Right now, ExchangeCoordinator only support HashPartitionings. - children.forall { - case e @ Exchange(hash: HashPartitioning, _, _) => true - case child => - child.outputPartitioning match { - case hash: HashPartitioning => true - case collection: PartitioningCollection => - collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) - case _ => false - } - } - } else { - // In this case, although we do not have Exchange operators, we may still need to - // shuffle data when we have more than one children because data generated by - // these children may not be partitioned in the same way. - // Please see the comment in withCoordinator for more details. 
- val supportsDistribution = - requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution]) - children.length > 1 && supportsDistribution - } - - val withCoordinator = - if (adaptiveExecutionEnabled && supportsCoordinator) { - val coordinator = - new ExchangeCoordinator( - children.length, - targetPostShuffleInputSize, - minNumPostShufflePartitions) - children.zip(requiredChildDistributions).map { - case (e: Exchange, _) => - // This child is an Exchange, we need to add the coordinator. - e.copy(coordinator = Some(coordinator)) - case (child, distribution) => - // If this child is not an Exchange, we need to add an Exchange for now. - // Ideally, we can try to avoid this Exchange. However, when we reach here, - // there are at least two children operators (because if there is a single child - // and we can avoid Exchange, supportsCoordinator will be false and we - // will not reach here.). Although we can make two children have the same number of - // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different. - // For example, let's say we have the following plan - // Join - // / \ - // Agg Exchange - // / \ - // Exchange t2 - // / - // t1 - // In this case, because a post-shuffle partition can include multiple pre-shuffle - // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes - // after shuffle. So, even we can use the child Exchange operator of the Join to - // have a number of post-shuffle partitions that matches the number of partitions of - // Agg, we cannot say these two children are partitioned in the same way. - // Here is another case - // Join - // / \ - // Agg1 Agg2 - // / \ - // Exchange1 Exchange2 - // / \ - // t1 t2 - // In this case, two Aggs shuffle data with the same column of the join condition. - // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same - // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2 - // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle - // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its - // pre-shuffle partitions by using another partitionStartIndices [0, 4]. - // So, Agg1 and Agg2 are actually not co-partitioned. - // - // It will be great to introduce a new Partitioning to represent the post-shuffle - // partitions when one post-shuffle partition includes multiple pre-shuffle partitions. - val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) - assert(targetPartitioning.isInstanceOf[HashPartitioning]) - Exchange(targetPartitioning, child, Some(coordinator)) - } - } else { - // If we do not need ExchangeCoordinator, the original children are returned. 
- children - } - - withCoordinator - } - - private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = { - val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution - val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering - var children: Seq[SparkPlan] = operator.children - assert(requiredChildDistributions.length == children.length) - assert(requiredChildOrderings.length == children.length) - - // Ensure that the operator's children satisfy their output distribution requirements: - children = children.zip(requiredChildDistributions).map { case (child, distribution) => - if (child.outputPartitioning.satisfies(distribution)) { - child - } else { - Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) - } - } - - // If the operator has multiple children and specifies child output distributions (e.g. join), - // then the children's output partitionings must be compatible: - if (children.length > 1 - && requiredChildDistributions.toSet != Set(UnspecifiedDistribution) - && !Partitioning.allCompatible(children.map(_.outputPartitioning))) { - - // First check if the existing partitions of the children all match. This means they are - // partitioned by the same partitioning into the same number of partitions. In that case, - // don't try to make them match `defaultPartitions`, just use the existing partitioning. - val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max - val useExistingPartitioning = children.zip(requiredChildDistributions).forall { - case (child, distribution) => { - child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - } - - children = if (useExistingPartitioning) { - // We do not need to shuffle any child's output. - children - } else { - // We need to shuffle at least one child's output. - // Now, we will determine the number of partitions that will be used by created - // partitioning schemes. - val numPartitions = { - // Let's see if we need to shuffle all child's outputs when we use - // maxChildrenNumPartitions. - val shufflesAllChildren = children.zip(requiredChildDistributions).forall { - case (child, distribution) => { - !child.outputPartitioning.guarantees( - createPartitioning(distribution, maxChildrenNumPartitions)) - } - } - // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the - // number of partitions. Otherwise, we use maxChildrenNumPartitions. - if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions - } - - children.zip(requiredChildDistributions).map { - case (child, distribution) => { - val targetPartitioning = - createPartitioning(distribution, numPartitions) - if (child.outputPartitioning.guarantees(targetPartitioning)) { - child - } else { - child match { - // If child is an exchange, we replace it with - // a new one having targetPartitioning. - case Exchange(_, c, _) => Exchange(targetPartitioning, c) - case _ => Exchange(targetPartitioning, child) - } - } - } - } - } - } - - // Now, we need to add ExchangeCoordinator if necessary. - // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges. - // However, with the way that we plan the query, we do not have a place where we have a - // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator - // at here for now. 
- // Once we finish https://issues.apache.org/jira/browse/SPARK-10665, - // we can first add Exchanges and then add coordinator once we have a DAG of query fragments. - children = withExchangeCoordinator(children, requiredChildDistributions) - - // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: - children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => - if (requiredOrdering.nonEmpty) { - // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { - Sort(requiredOrdering, global = false, child = child) - } else { - child - } - } else { - child - } - } - - operator.withNewChildren(children) - } - - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ Exchange(partitioning, child, _) => - child.children match { - case Exchange(childPartitioning, baseChild, _)::Nil => - if (childPartitioning.guarantees(partitioning)) child else operator - case _ => operator - } - case operator: SparkPlan => ensureDistributionAndOrdering(operator) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a64da22580..ddc08822f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.execution.joins -import scala.concurrent._ -import scala.concurrent.duration._ - import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ThreadUtils import org.apache.spark.util.collection.CompactBuffer /** @@ -52,60 +48,25 @@ case class BroadcastHashJoin( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - val timeout: Duration = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - row.copy() - }.collect() - // The following line doesn't run in a job so we cannot track the metric value. However, we - // have already tracked it in the above lines. So here we can use - // `SQLMetrics.nullLongMetric` to ignore it. - // TODO: move this check into HashedRelation - val hashed = if (canJoinKeyFitWithinLong) { - LongHashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } else { - HashedRelation( - input.iterator, buildSideKeyGenerator, input.size) - } - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture + override def requiredChildDistribution: Seq[Distribution] = { + val mode = HashedRelationBroadcastMode( + canJoinKeyFitWithinLong, + rewriteKeyExpr(buildKeys), + buildPlan.output) + buildSide match { + case BuildLeft => + BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastRelation = Await.result(broadcastFuture, timeout) - + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value @@ -160,7 +121,7 @@ case class BroadcastHashJoin( */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName @@ -362,9 +323,3 @@ case class BroadcastHashJoin( } } } - -object BroadcastHashJoin { - - private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 4f1cfd2e81..1f99fbedde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{InternalAccumulator, TaskContext} +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ 
-38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash( override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + override def requiredChildDistribution: Seq[Distribution] = { + val mode = if (condition.isEmpty) { + HashSetBroadcastMode(rightKeys, right.output) + } else { + HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output) + } + UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val input = right.execute().map { row => - row.copy() - }.collect() - if (condition.isEmpty) { - val hashSet = buildKeyHashSet(input.toIterator) - val broadcastedRelation = sparkContext.broadcast(hashSet) - + val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]() left.execute().mapPartitionsInternal { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows) } } else { - val hashRelation = - HashedRelation(input.toIterator, rightKeyGenerator, input.size) - val broadcastedRelation = sparkContext.broadcast(hashRelation) - + val broadcastedRelation = right.executeBroadcast[HashedRelation]() left.execute().mapPartitionsInternal { streamIter => val hashedRelation = broadcastedRelation.value TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4585cbda92..e8bd7f69db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} @@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin( buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]) extends BinaryNode { - // TODO: Override requiredChildDistribution. 
override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def requiredChildDistribution: Seq[Distribution] = buildSide match { + case BuildLeft => + BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil + case BuildRight => + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + private[this] def genResultProjection: InternalRow => InternalRow = { - UnsafeProjection.create(schema) + UnsafeProjection.create(schema) } override def outputPartitioning: Partitioning = streamed.outputPartitioning @@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => + val relation = broadcastedRelation.value + val matchedRows = new CompactBuffer[InternalRow] - val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(relation.length) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin( var i = 0 var streamRowMatched = false - while (i < broadcastedRelation.value.size) { - val broadcastedRow = broadcastedRelation.value(i) + while (i < relation.length) { + val broadcastedRow = relation(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 0220e0b8a7..1cb6a00617 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric @@ -44,22 +45,7 @@ trait HashSemiJoin { protected def buildKeyHashSet( buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { - val hashSet = new java.util.HashSet[InternalRow]() - - // Create a Hash set of buildKeys - val rightKey = rightKeyGenerator - while (buildIter.hasNext) { - val currentRow = buildIter.next() - val rowKey = rightKey(currentRow) - if (!rowKey.anyNull) { - val keyExists = hashSet.contains(rowKey) - if (!keyExists) { - hashSet.add(rowKey.copy()) - } - } - } - - hashSet + HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter) } protected def hashSemiJoin( @@ -92,3 +78,36 @@ trait HashSemiJoin { } } } + +private[execution] object HashSemiJoin { + def buildKeyHashSet( + keys: Seq[Expression], + attributes: Seq[Attribute], + rows: Iterator[InternalRow]): 
java.util.HashSet[InternalRow] = { + val hashSet = new java.util.HashSet[InternalRow]() + + // Create a Hash set of buildKeys + val key = UnsafeProjection.create(keys, attributes) + while (rows.hasNext) { + val currentRow = rows.next() + val rowKey = key(currentRow) + if (!rowKey.anyNull) { + val keyExists = hashSet.contains(rowKey) + if (!keyExists) { + hashSet.add(rowKey.copy()) + } + } + } + hashSet + } +} + +/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */ +private[execution] case class HashSetBroadcastMode( + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = { + HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0978570d42..606269bf25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} import org.apache.spark.sql.execution.local.LocalNode -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.MemoryLocation import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -675,3 +674,20 @@ private[joins] object LongHashedRelation { } } } + +/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
*/ +private[execution] case class HashedRelationBroadcastMode( + canJoinKeyFitWithinLong: Boolean, + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { + + def transform(rows: Array[InternalRow]): HashedRelation = { + val generator = UnsafeProjection.create(keys, attributes) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(rows.iterator, generator, rows.length) + } else { + HashedRelation(rows.iterator, generator, rows.length) + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index ce758d63b3..df6dac8818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * for hash join. */ case class LeftSemiJoinBNL( - streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - extends BinaryNode { - // TODO: Override requiredChildDistribution. + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode { override private[sql] lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) @@ -46,27 +44,28 @@ case class LeftSemiJoinBNL( /** The Broadcast relation */ override def right: SparkPlan = broadcast + override def requiredChildDistribution: Seq[Distribution] = { + UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil + } + @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - val broadcastedRelation = - sparkContext.broadcast(broadcast.execute().map { row => - row.copy() - }.collect().toIndexedSeq) + val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow + val relation = broadcastedRelation.value streamedIter.filter(streamedRow => { var i = 0 var matched = false - while (i < broadcastedRelation.value.size && !matched) { - val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + while (i < relation.length && !matched) { + if (boundCondition(joinedRow(streamedRow, relation(i)))) { matched = true } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index ef76847bcb..cd543d4195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange /** @@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + child.execute(), child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -110,7 +112,8 @@ case class TakeOrderedAndProject( } } val shuffled = new ShuffledRowRDD( - Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer)) + ShuffleExchange.prepareShuffleDependency( + localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) if (projectList.isDefined) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index e8d0678989..83d7953aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -23,9 +23,9 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. 
*/ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected) + assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 99ba2e2061..50a246489e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} -import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 @@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } atFirstAgg = true } - case e: Exchange => atFirstAgg = false + case e: ShuffleExchange => atFirstAgg = false case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 35ff1c40fe..b1c588a63d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql._ +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext @@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
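(Editorial reading aid, not part of the patch.) After the rename, the test updates above all rely on the same plan-inspection idiom: collect `ShuffleExchange` nodes from the executed plan and assert on their number or properties. A hedged sketch of that idiom, e.g. inside a suite like the ones above built from this branch; the helper names are ours:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.exchange.ShuffleExchange

// Collect every shuffle exchange planned for a DataFrame's physical plan.
def shuffleExchanges(df: DataFrame): Seq[ShuffleExchange] =
  df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }

// Mirrors verifyNumExchanges in CachedTableSuite above.
def assertNumShuffles(df: DataFrame, expected: Int): Unit =
  assert(shuffleExchanges(df).size == expected,
    s"expected $expected shuffle exchange(s) in:\n${df.queryExecution.executedPlan}")
```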
val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: Exchange => e + case e: ShuffleExchange => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: Exchange => + case e: ShuffleExchange => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 87bff3295f..d4f22de90c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical.SinglePartition +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { @@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => Exchange(SinglePartition, plan), + plan => ShuffleExchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 250ce8f866..4de56783fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange} import 
org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: Exchange => exchange + case exchange: ShuffleExchange => exchange }.length assert(numExchanges === 5) } @@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = Exchange(finalPartitioning, + val inputPlan = ShuffleExchange(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning 
= childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: Exchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index e25b5e0610..a256ee95a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext} +import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ /** @@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = df3.queryExecution.sparkPlan + val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index e22a810a6b..6dfff3770b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) + val broadcastJoin = joins.BroadcastHashJoin( + leftKeys, + rightKeys, + Inner, + side, + boundCondition, + leftPlan, + rightPlan) + EnsureRequirements(sqlContext).apply(broadcastJoin) } def makeSortMergeJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index f4b01fbad0..cd6b6fcbb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import 
org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 9c86084f9b..f3ad8409e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9ba645626f..a05a57c0f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -22,8 +22,9 @@ import java.io.File import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{Exchange, PhysicalRDD} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.SortMergeJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin]) val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin] - assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft) - assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight) + assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft) + assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight) } } } @@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet agged.sort("i", 
"j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) } } @@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val agged = hiveContext.table("bucketed_table").groupBy("i").count() // make sure we fall back to non-bucketing mode and can't avoid shuffle - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined) checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i")) } }