[SPARK-32740][SQL] Refactor common partitioning/distribution logic to BaseAggregateExec

### What changes were proposed in this pull request?

All three aggregate physical operators (`HashAggregateExec`, `ObjectHashAggregateExec` and `SortAggregateExec`) share the same `outputPartitioning` and `requiredChildDistribution` logic. This PR refactors that shared logic into their superclass `BaseAggregateExec` to avoid code duplication and future bugs, similar to what was done for `HashJoin` and `ShuffledJoin`.
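
In outline, the shared overrides that move up into the trait look like this (a condensed sketch of the change; the full diff is below):

```scala
trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
  def requiredChildDistributionExpressions: Option[Seq[Expression]]
  def resultExpressions: Seq[NamedExpression]

  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

  // outputPartitioning is provided by AliasAwareOutputPartitioning, driven by these expressions
  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions

  override def requiredChildDistribution: List[Distribution] = {
    requiredChildDistributionExpressions match {
      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
      case None => UnspecifiedDistribution :: Nil
    }
  }
}
```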

### Why are the changes needed?

Reduces duplicated code across the aggregate operators and prevents future bugs where one class is updated but the others are forgotten. A similar refactoring was already done for joins (`HashJoin` and `ShuffledJoin`).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests, since this is a pure refactoring that adds no new logic.
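
For reference, a minimal, illustrative sketch (not part of this PR) of queries that exercise each of the three operators:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_list

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.range(100).selectExpr("id % 10 AS k", "id AS v")

// HashAggregateExec: aggregates with fixed-width, mutable buffer types (e.g. sum)
df.groupBy("k").sum("v").explain()

// ObjectHashAggregateExec: TypedImperativeAggregate functions such as collect_list
df.groupBy("k").agg(collect_list("v")).explain()

// SortAggregateExec: chosen when object hash aggregation is disabled
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")
df.groupBy("k").agg(collect_list("v")).explain()
```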

Closes #29583 from c21/aggregate-refactor.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
4 changed files with 21 additions and 48 deletions

`BaseAggregateExec.scala`

```diff
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
-import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode}
 
 /**
  * Holds common logic for aggregate operators
  */
-trait BaseAggregateExec extends UnaryExecNode {
+trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
@@ -81,4 +83,16 @@ trait BaseAggregateExec extends UnaryExecNode {
       // attributes of the child Aggregate, when the child Aggregate contains the subquery in
       // AggregateFunction. See SPARK-31620 for more details.
       AttributeSet(inputAggBufferAttributes.filterNot(child.output.contains))
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
 }
```

`HashAggregateExec.scala`

```diff
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
@@ -54,8 +53,7 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen
-  with AliasAwareOutputPartitioning {
+  with BlockingOperatorWithCodegen {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -71,18 +69,6 @@ case class HashAggregateExec(
     "avgHashProbe" ->
       SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
```

`ObjectHashAggregateExec.scala`

```diff
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -67,7 +66,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec {
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -78,18 +77,6 @@ case class ObjectHashAggregateExec(
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build")
   )
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")
```

`SortAggregateExec.scala`

```diff
@@ -22,9 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
+import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -39,28 +38,15 @@ case class SortAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with AliasAwareOutputPartitioning
-  with AliasAwareOutputOrdering {
+  with AliasAwareOutputOrdering {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
 
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   override protected def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
```