From ce473b223ac64b60662a2e1731891a89fa3d126b Mon Sep 17 00:00:00 2001
From: Cheng Su
Date: Mon, 31 Aug 2020 15:43:13 +0900
Subject: [PATCH] [SPARK-32740][SQL] Refactor common partitioning/distribution logic to BaseAggregateExec

### What changes were proposed in this pull request?

All three aggregate physical operators, `HashAggregateExec`, `ObjectHashAggregateExec` and `SortAggregateExec`, share the same `outputPartitioning` and `requiredChildDistribution` logic. Refactor this shared logic into their super class `BaseAggregateExec` to avoid code duplication and future bugs (similar to `HashJoin` and `ShuffledJoin`).

### Why are the changes needed?

To reduce duplicated code across classes and to prevent future bugs where only one class is updated and the others are forgotten. We already did a similar refactoring for joins (`HashJoin` and `ShuffledJoin`).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests, as this is a pure refactoring with no new logic added.

Closes #29583 from c21/aggregate-refactor.

Authored-by: Cheng Su
Signed-off-by: Takeshi Yamamuro
---
 .../aggregate/BaseAggregateExec.scala         | 20 ++++++++++++++++---
 .../aggregate/HashAggregateExec.scala         | 16 +--------------
 .../aggregate/ObjectHashAggregateExec.scala   | 15 +-------------
 .../aggregate/SortAggregateExec.scala         | 18 ++---------------
 4 files changed, 21 insertions(+), 48 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
index f6d04601fc..efba51706c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
-import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode}
 
 /**
  * Holds common logic for aggregate operators
  */
-trait BaseAggregateExec extends UnaryExecNode {
+trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
@@ -81,4 +83,16 @@ trait BaseAggregateExec extends UnaryExecNode {
       // attributes of the child Aggregate, when the child Aggregate contains the subquery in
       // AggregateFunction. See SPARK-31620 for more details.
       AttributeSet(inputAggBufferAttributes.filterNot(child.output.contains))
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 40b95df44a..dcb465707a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
@@ -54,8 +53,7 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen
-  with AliasAwareOutputPartitioning {
+  with BlockingOperatorWithCodegen {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -71,18 +69,6 @@ case class HashAggregateExec(
     "avgHashProbe" ->
       SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index 231adbeb24..02f666383d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -67,7 +66,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec {
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -78,18 +77,6 @@ case class ObjectHashAggregateExec(
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build")
   )
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index 48763686e4..aa1559e2ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -22,9 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
+import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -39,28 +38,15 @@ case class SortAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with AliasAwareOutputPartitioning
-  with AliasAwareOutputOrdering {
+  with AliasAwareOutputOrdering {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
 
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   override protected def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
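
To summarize the behavior being consolidated: `requiredChildDistributionExpressions = None` places no requirement on the child's partitioning, `Some(Nil)` forces all rows onto a single partition (global aggregation), and a non-empty `Some(exprs)` requires the child to be clustered by those expressions. Below is a minimal, self-contained Scala sketch of that decision logic; the simplified `Distribution` ADT, the `resolveDistribution` helper, and the `AggregateDistributionSketch` object are illustrative stand-ins, not Spark's actual classes.

```scala
// Simplified stand-ins for Spark's Distribution hierarchy (illustration only).
sealed trait Distribution
case object UnspecifiedDistribution extends Distribution                   // child may be partitioned arbitrarily
case object AllTuples extends Distribution                                 // all rows must land in one partition
final case class ClusteredDistribution(exprs: Seq[String]) extends Distribution // rows clustered by these expressions

object AggregateDistributionSketch {
  // Mirrors the decision hoisted into BaseAggregateExec.requiredChildDistribution:
  //   None            -> no requirement on the child
  //   Some(Nil)       -> global aggregation, everything on one partition
  //   Some(non-empty) -> clustered by the grouping expressions
  def resolveDistribution(requiredExprs: Option[Seq[String]]): List[Distribution] =
    requiredExprs match {
      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
      case Some(exprs)                  => ClusteredDistribution(exprs) :: Nil
      case None                         => UnspecifiedDistribution :: Nil
    }

  def main(args: Array[String]): Unit = {
    println(resolveDistribution(None))             // List(UnspecifiedDistribution)
    println(resolveDistribution(Some(Nil)))        // List(AllTuples)
    println(resolveDistribution(Some(Seq("key")))) // List(ClusteredDistribution(List(key)))
  }
}
```

In the patch itself this logic lives once in `BaseAggregateExec`, so `HashAggregateExec`, `ObjectHashAggregateExec`, and `SortAggregateExec` all inherit it instead of each carrying its own copy.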