[SPARK-32740][SQL] Refactor common partitioning/distribution logic to BaseAggregateExec
### What changes were proposed in this pull request?

The three aggregate physical operators, `HashAggregateExec`, `ObjectHashAggregateExec`, and `SortAggregateExec`, share identical `outputPartitioning` and `requiredChildDistribution` logic. This PR refactors that shared logic into their common superclass, `BaseAggregateExec`, to avoid code duplication and future bugs (similar to the earlier refactoring of `HashJoin` and `ShuffledJoin`).

### Why are the changes needed?

To reduce duplicated code across the three classes and to prevent future bugs where one class is updated but another is forgotten. We already did a similar refactoring for joins (`HashJoin` and `ShuffledJoin`).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests, since this is a pure refactoring with no new logic added.

Closes #29583 from c21/aggregate-refactor.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
parent a1e459ed9f
commit ce473b223a
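Before the diff itself, a quick way to surface two of the three affected operators in a local physical plan. This is an illustrative snippet, not part of the commit; the session setup and column names are arbitrary:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, sum}

object AggPlanDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("agg-plan-demo").getOrCreate()
  import spark.implicits._

  val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")

  // Fixed-width aggregation buffers take the hash-based path:
  // the plan shows HashAggregate (partial + final pair).
  df.groupBy($"k").agg(sum($"v")).explain()

  // Object-typed buffers (e.g. collect_list) take the object hash path
  // by default: the plan shows ObjectHashAggregate.
  df.groupBy($"k").agg(collect_list($"v")).explain()

  spark.stop()
}
```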
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

```diff
@@ -17,14 +17,16 @@
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
-import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, UnaryExecNode}
 
 /**
  * Holds common logic for aggregate operators
  */
-trait BaseAggregateExec extends UnaryExecNode {
+trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
+  def requiredChildDistributionExpressions: Option[Seq[Expression]]
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
 
@@ -81,4 +83,16 @@ trait BaseAggregateExec extends UnaryExecNode {
       // attributes of the child Aggregate, when the child Aggregate contains the subquery in
       // AggregateFunction. See SPARK-31620 for more details.
       AttributeSet(inputAggBufferAttributes.filterNot(child.output.contains))
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
 }
```
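For context (not part of the diff): the planner (see `AggUtils`) supplies `requiredChildDistributionExpressions`: `None` for partial aggregates, `Some(groupingAttributes)` for final aggregates with grouping keys, and `Some(Nil)` for global aggregates. A dependency-free sketch of the mapping the consolidated method implements, with string stand-ins for Catalyst expressions:

```scala
// Stand-ins for Spark's physical Distribution types, just to show the mapping.
sealed trait Distribution
case object AllTuples extends Distribution
final case class ClusteredDistribution(exprs: Seq[String]) extends Distribution
case object UnspecifiedDistribution extends Distribution

def requiredChildDistribution(exprs: Option[Seq[String]]): List[Distribution] =
  exprs match {
    case Some(es) if es.isEmpty => AllTuples :: Nil                 // global aggregate
    case Some(es)               => ClusteredDistribution(es) :: Nil // final, grouped
    case None                   => UnspecifiedDistribution :: Nil   // partial aggregate
  }

// Quick checks of the three branches:
assert(requiredChildDistribution(Some(Nil)) == List(AllTuples))
assert(requiredChildDistribution(Some(Seq("k"))) == List(ClusteredDistribution(Seq("k"))))
assert(requiredChildDistribution(None) == List(UnspecifiedDistribution))
```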
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

```diff
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
 
@@ -54,8 +53,7 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen
-  with AliasAwareOutputPartitioning {
+  with BlockingOperatorWithCodegen {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -71,18 +69,6 @@ case class HashAggregateExec(
     "avgHashProbe" ->
       SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
```
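The `AliasAwareOutputPartitioning` mixin that moved up into `BaseAggregateExec` rewrites the child's partitioning through the aliases in `outputExpressions`, so partitioning information survives a rename of a grouping key. A simplified, dependency-free model of that idea (assumed behavior; the real trait works on Catalyst expressions, not strings):

```scala
// Simplified model: partitioning expressions that are aliased in the operator's
// output are reported under their new names instead of being lost.
final case class HashPartitioning(exprs: Seq[String])

def aliasAwareOutputPartitioning(
    childPartitioning: HashPartitioning,
    aliases: Map[String, String]): HashPartitioning =
  HashPartitioning(childPartitioning.exprs.map(e => aliases.getOrElse(e, e)))

// Child is hash-partitioned by "k" and the aggregate outputs "k AS k2":
// the aggregate reports hash partitioning by "k2" rather than none at all,
// which can spare downstream operators an extra shuffle.
assert(aliasAwareOutputPartitioning(HashPartitioning(Seq("k")), Map("k" -> "k2"))
  == HashPartitioning(Seq("k2")))
```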
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

```diff
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -67,7 +66,7 @@ case class ObjectHashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends BaseAggregateExec with AliasAwareOutputPartitioning {
+  extends BaseAggregateExec {
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
 
@@ -78,18 +77,6 @@ case class ObjectHashAggregateExec(
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build")
   )
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
     val aggTime = longMetric("aggTime")
```
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

```diff
@@ -22,9 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
+import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
 
@@ -39,28 +38,15 @@ case class SortAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with AliasAwareOutputPartitioning
-  with AliasAwareOutputOrdering {
+  with AliasAwareOutputOrdering {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
 
-  override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
-
   override protected def orderingExpressions: Seq[SortOrder] = {
     groupingExpressions.map(SortOrder(_, Ascending))
   }
```
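Note that `SortAggregateExec` keeps its ordering-specific overrides (`requiredChildOrdering`, `orderingExpressions`) while the partitioning and distribution logic now comes from `BaseAggregateExec`. To see this operator in a plan, one option is to disable the object hash path. A hedged snippet, reusing the `spark` and `df` values from the earlier example:

```scala
// With the object hash path disabled, object-typed buffers fall back to
// sort-based aggregation: SortAggregate appears in the plan with a Sort on
// the grouping keys below it, matching requiredChildOrdering above.
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")
df.groupBy($"k").agg(collect_list($"v")).explain()
```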