Revert "[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function"

This reverts commit c8d78a70b4.
Wenchen Fan 2021-04-23 15:55:30 +08:00
parent 20d68dc2f4
commit fdccd88c2a
15 changed files with 68 additions and 247 deletions
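For context, the reverted change guarded aggregate expressions that apply a function to a grouping expression without any aggregate function. A minimal reproduction of that scenario, sketched from the SQL tests removed at the end of this diff (it assumes a SparkSession named `spark` and a table `testData` with a nullable column `a`):

// Sketch only: the reverted guard existed so that the optimizer would not rewrite
// Not(IsNull(a)) into IsNotNull(a), which no longer matches the grouping
// expression IsNull(a).
val df = spark.sql(
  """SELECT not(a IS NULL), count(*) AS c
    |FROM testData
    |GROUP BY a IS NULL""".stripMargin)
df.explain(true)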


@@ -17,8 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
@@ -52,22 +52,3 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
}
}
}
/**
* Updates nullability of [[GroupingExprRef]]s in a resolved LogicalPlan by using the nullability of
* referenced grouping expression.
*/
object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case a: Aggregate =>
val nullabilities = a.groupingExpressions.map(_.nullable).toArray
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
g.copy(nullable = nullabilities(g.ordinal))
}.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = newAggregateExpressions)
}
}


@@ -35,7 +35,7 @@ trait AliasHelper {
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
val aliasMap = plan.aggregateExpressions.collect {
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a)


@@ -80,14 +80,6 @@ object AggregateExpression {
filter,
NamedExpression.newExprId)
}
def containsAggregate(expr: Expression): Boolean = {
expr.find(isAggregate).isDefined
}
def isAggregate(expr: Expression): Boolean = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
}
}
/**


@@ -277,22 +277,3 @@ object GroupingAnalytics {
}
}
}
/**
* A reference to a grouping expression in an [[Aggregate]] node.
*
* @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
* refers to.
* @param dataType The [[DataType]] of the referenced grouping expression.
* @param nullable True if null is a valid value for the referenced grouping expression.
*/
case class GroupingExprRef(
ordinal: Int,
dataType: DataType,
nullable: Boolean)
extends LeafExpression with Unevaluable {
override def stringArgs: Iterator[Any] = {
Iterator(ordinal)
}
}
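For illustration, and following the example in the Aggregate scaladoc removed further down in this diff, the aggregate expression `not(c IS NULL)` for `GROUP BY c IS NULL` would have carried such a reference instead of repeating the grouping expression (a sketch against the pre-revert catalyst classes):

import org.apache.spark.sql.catalyst.expressions.{GroupingExprRef, Not}
import org.apache.spark.sql.types.BooleanType

// Ordinal 0 points at the first grouping expression, `c IS NULL`, which has
// BooleanType and never evaluates to null, hence nullable = false.
val aggExpr = Not(GroupingExprRef(ordinal = 0, dataType = BooleanType, nullable = false))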


@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
@@ -26,6 +26,15 @@ import org.apache.spark.sql.catalyst.rules.Rule
*/
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// One place where this optimization is invalid is an aggregation where the select
// list expression is a function of a grouping expression:
//
// SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
//
// cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
// optimization for Aggregates (although this misses some cases where the optimization
// can be made).
case a: Aggregate => a
case p => p.transformExpressionsUp {
// Remove redundant field extraction.
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
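A short sketch of why the comment above skips Aggregate nodes entirely (assumes a SparkSession `spark` and a hypothetical table `tbl(a INT, b INT)`):

// The original query is a valid aggregate, but the "simplified" form references
// `a`, which is not a grouping expression, so it would not even pass analysis.
spark.sql("SELECT struct(a, b).a FROM tbl GROUP BY struct(a, b)").show()  // valid
// spark.sql("SELECT a FROM tbl GROUP BY struct(a, b)").show()            // AnalysisException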


@@ -1,34 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
/**
* This rule ensures that [[Aggregate]] nodes contain all required [[GroupingExprRef]]
* references for the optimization phase.
*/
object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan transform {
case a: Aggregate =>
Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
}
}
}


@@ -119,8 +119,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
OptimizeUpdateFields,
SimplifyExtractValueOps,
OptimizeCsvJsonExprs,
CombineConcats,
UpdateGroupingExprRefNullability) ++
CombineConcats) ++
extendedOperatorOptimizationRules
val operatorOptimizationBatch: Seq[Batch] = {
@@ -149,7 +148,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
EliminateView,
ReplaceExpressions,
RewriteNonCorrelatedExists,
EnforceGroupingReferencesInAggregates,
ComputeCurrentTime,
GetCurrentDatabaseAndCatalog(catalogManager)) ::
//////////////////////////////////////////////////////////////////////////////////////////
@@ -269,9 +267,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceUpdateFieldsExpression.ruleName ::
EnforceGroupingReferencesInAggregates.ruleName ::
UpdateGroupingExprRefNullability.ruleName :: Nil
ReplaceUpdateFieldsExpression.ruleName :: Nil
/**
* Optimize all the subqueries inside expression.
@@ -512,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)
val newAggregate = Aggregate.withGroupingRefs(
val newAggregate = upper.copy(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
@@ -528,19 +524,23 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
}
private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
val upperHasNoAggregateExpressions =
!upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
lower
.aggregateExpressions
.filter(_.deterministic)
.filterNot(AggregateExpression.containsAggregate)
.filter(!isAggregate(_))
.map(_.toAttribute)
))
upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
}
private def isAggregate(expr: Expression): Boolean = {
expr.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isDefined
}
}
/**
@@ -1978,18 +1978,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
val droppedGroupsBefore =
grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])
a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggregateExpressions)
a.copy(groupingExpressions = newGrouping)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
@@ -2010,25 +1999,7 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
if (newGrouping.size == grouping.size) {
a
} else {
var i = 0
val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
i += 1
0
} else {
1
})
).toArray
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])
a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggregateExpressions)
a.copy(groupingExpressions = newGrouping)
}
}
}
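The removed bookkeeping above computed, for each grouping position, how many dropped expressions precede it, so that GroupingExprRef ordinals could be shifted down accordingly. A self-contained sketch of that scanLeft, with illustrative values only:

// For GROUP BY 1, a, 2, b the foldable literals 1 and 2 are dropped.
val foldable = Seq(true, false, true, false)
val droppedBefore = foldable.scanLeft(0)((n, f) => n + (if (f) 1 else 0)).toArray
// droppedBefore(i) = number of dropped groups before index i: Array(0, 1, 1, 2, 2),
// so a reference to `a` (ordinal 1) becomes ordinal 0 and `b` (ordinal 3) becomes 1.
assert(droppedBefore.take(4).sameElements(Array(0, 1, 1, 2)))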


@@ -633,10 +633,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, _, child) =>
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
.map(extractCorrelatedScalarSubqueries(_, subqueries))
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the


@@ -287,7 +287,7 @@ object PhysicalAggregation {
(Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
def unapply(a: Any): Option[ReturnType] = a match {
case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
// A single aggregate expression might appear multiple times in resultExpressions.
// In order to avoid evaluating an individual aggregate function multiple times, we'll
// build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,9 +297,11 @@
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
// addExpr() always returns false for non-deterministic expressions and does not add them.
case a
if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
a
case agg: AggregateExpression
if !equivalentAggregateExpressions.addExpr(agg) => agg
case udf: PythonUDF
if PythonUDF.isGroupedAggPandasUDF(udf) &&
!equivalentAggregateExpressions.addExpr(udf) => udf
}
}
@@ -320,7 +322,7 @@
// which takes the grouping columns and final aggregate result buffer as input.
// Thus, we must re-write the result expressions so that their attributes match up with
// the attributes of the final result projection's input row:
val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transformDown {
case ae: AggregateExpression =>
// The final aggregation buffer's attributes will be `finalAggregationAttributes`,


@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.plans.logical
import scala.collection.mutable
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, TypeCoercion, TypeCoercionBase}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
@@ -783,23 +781,14 @@ case class Range(
/**
* This is a Group by operator with the aggregate functions and projections.
*
* @param groupingExpressions Expressions for grouping keys.
* @param aggregateExpressions Expressions for a project list, which can contain
* [[AggregateExpression]]s and [[GroupingExprRef]]s.
* @param child The child of the aggregate node.
* @param groupingExpressions expressions for grouping keys
* @param aggregateExpressions expressions for a project list, which could contain
* [[AggregateExpression]]s.
*
* Expressions without aggregate functions in [[aggregateExpressions]] can contain
* [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. These
* references ensure that optimization rules don't change the aggregate expressions to invalid ones
* that no longer refer to any grouping expressions and also simplify the expression transformations
* on the node (need to transform the expression only once).
*
* For example, in the following query Spark shouldn't optimize the aggregate expression
* `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`:
* SELECT not(c IS NULL)
* FROM t
* GROUP BY c IS NULL
* Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`.
* Note: Currently, aggregateExpressions is the project list of this Group by operator. Before
* separating projection from grouping and aggregate, we should avoid expression-level optimization
* on aggregateExpressions, which could reference an expression in groupingExpressions.
* For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]]
*/
case class Aggregate(
groupingExpressions: Seq[Expression],
@@ -826,21 +815,8 @@ case class Aggregate(
}
}
private def expandGroupingReferences(e: Expression): Expression = {
e match {
case _ if AggregateExpression.isAggregate(e) => e
case g: GroupingExprRef => groupingExpressions(g.ordinal)
case _ => e.mapChildren(expandGroupingReferences)
}
}
lazy val aggregateExpressionsWithoutGroupingRefs = {
aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression])
}
override lazy val validConstraints: ExpressionSet = {
val nonAgg = aggregateExpressionsWithoutGroupingRefs.
filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
getAllValidConstraints(nonAgg)
}
@@ -848,51 +824,6 @@ case class Aggregate(
copy(child = newChild)
}
object Aggregate {
private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = {
val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)]
var i = 0
groupingExpressions.foreach { ge =>
if (!ge.foldable && ge.children.nonEmpty &&
!complexGroupingExpressions.contains(ge.canonicalized)) {
complexGroupingExpressions += ge.canonicalized -> (ge, i)
}
i += 1
}
complexGroupingExpressions
}
private def insertGroupingReferences(
aggregateExpressions: Seq[NamedExpression],
groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = {
def insertGroupingExprRefs(e: Expression): Expression = {
e match {
case _ if AggregateExpression.isAggregate(e) => e
case _ if groupingExpressions.contains(e.canonicalized) =>
val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized)
GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable)
case _ => e.mapChildren(insertGroupingExprRefs)
}
}
aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression])
}
def withGroupingRefs(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: LogicalPlan): Aggregate = {
val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions)
val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) {
insertGroupingReferences(aggregateExpressions, complexGroupingExpressions)
} else {
aggregateExpressions
}
new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child)
}
}
case class Window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],


@@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy('a + 'b)(('a + 'b) as 'c)
.analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected))
comparePlans(optimized, expected)
}
}
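The expected plan above is the collapsed single aggregate. A hedged sketch of the kind of two-level query the rule removes, assuming (as is conventional in these suites) a LocalRelation named `testRelation` with int columns `a` and `b`:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// The lower aggregate only computes the grouping key that the upper one re-groups by,
// so RemoveRedundantAggregates should collapse this plan to the single aggregate
// testRelation.groupBy('a + 'b)(('a + 'b) as 'c).
val testRelation = LocalRelation('a.int, 'b.int)
val query = testRelation
  .groupBy('a + 'b)(('a + 'b) as 'c)
  .groupBy('c)('c)
  .analyze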


@@ -36,8 +36,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Finish Analysis", Once,
EnforceGroupingReferencesInAggregates) ::
Batch("collapse projections", FixedPoint(10),
CollapseProject) ::
Batch("Constant Folding", FixedPoint(10),
@@ -59,7 +57,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
val optimized = Optimizer.execute(originalQuery.analyze)
assert(optimized.resolved, "optimized plans must be still resolvable")
comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze))
comparePlans(optimized, correctAnswer.analyze)
}
test("explicit get from namedStruct") {
@@ -407,6 +405,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
val arrayAggRel = relation.groupBy(
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
checkRule(arrayAggRel, arrayAggRel)
// This could be done if we had a more complex rule that checks that
// the CreateMap does not come from key.
val originalQuery = relation
.groupBy('id)(
GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
)
checkRule(originalQuery, originalQuery)
}
test("SPARK-23500: namedStruct and getField in the same Project #1") {


@@ -40,7 +40,6 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e) ||
e.isInstanceOf[GroupingExprRef] ||
agg.groupingExpressions.exists(_.semanticEquals(e))
}
@@ -120,8 +119,23 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
groupingExpr += expr
}
}
val aggExpr = agg.aggregateExpressions.map { expr =>
expr.transformUp {
// PythonUDF over aggregate was pulled out by ExtractPythonUDFFromAggregate.
// PythonUDF here should be either
// 1. Argument of an aggregate function.
// CheckAnalysis guarantees the arguments are deterministic.
// 2. PythonUDF in grouping key. Grouping key must be deterministic.
// 3. PythonUDF not in grouping key. It either has no arguments or has the grouping key
// in its arguments; such a PythonUDF was pulled out by ExtractPythonUDFFromAggregate, too.
case p: PythonUDF if p.udfDeterministic =>
val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
attributeMap.getOrElse(canonicalized, p)
}.asInstanceOf[NamedExpression]
}
agg.copy(
groupingExpressions = groupingExpr.toSeq,
aggregateExpressions = aggExpr,
child = Project((projList ++ agg.child.output).toSeq, agg.child))
}


@@ -179,12 +179,3 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(
-- Aggregate with multiple distinct decimal columns
SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
-- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
SELECT not(a IS NULL), count(*) AS c
FROM testData
GROUP BY a IS NULL;
SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
FROM testData
GROUP BY a IS NULL;


@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 64
-- Number of queries: 62
-- !query
@@ -642,25 +642,3 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
struct<avg(DISTINCT decimal_col):decimal(13,4),sum(DISTINCT decimal_col):decimal(19,0)>
-- !query output
1.0000 1
-- !query
SELECT not(a IS NULL), count(*) AS c
FROM testData
GROUP BY a IS NULL
-- !query schema
struct<(NOT (a IS NULL)):boolean,c:bigint>
-- !query output
false 2
true 7
-- !query
SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
FROM testData
GROUP BY a IS NULL
-- !query schema
struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
-- !query output
0.7604953758285915 7
1.0 2