Revert "[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions without aggregate function"
This reverts commit c8d78a70b4.

parent 20d68dc2f4
commit fdccd88c2a
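For context: the reverted change introduced GroupingExprRef so that aggregate
expressions keep referring to their grouping expressions by ordinal, preventing
optimizer rules from rewriting them into invalid or semantically different
expressions. The guarded case, quoted from the reverted change's own docs and
tests (testData is the SQL test fixture):

    -- The grouping expression is `a IS NULL`, so the select list must keep
    -- referring to it; `not(a IS NULL)` must not be rewritten to `a IS NOT NULL`.
    SELECT not(a IS NULL), count(*) AS c
    FROM testData
    GROUP BY a IS NULL;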
@@ -17,8 +17,8 @@
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GroupingExprRef, NamedExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -52,22 +52,3 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] {
     }
   }
 }
-
-/**
- * Updates the nullability of [[GroupingExprRef]]s in a resolved LogicalPlan using the nullability
- * of the referenced grouping expression.
- */
-object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a: Aggregate =>
-      val nullabilities = a.groupingExpressions.map(_.nullable).toArray
-
-      val newAggregateExpressions =
-        a.aggregateExpressions.map(_.transform {
-          case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
-            g.copy(nullable = nullabilities(g.ordinal))
-        }.asInstanceOf[NamedExpression])
-
-      a.copy(aggregateExpressions = newAggregateExpressions)
-  }
-}
@@ -35,7 +35,7 @@ trait AliasHelper {
   protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
     // Find all the aliased expressions in the aggregate list that don't include any actual
     // AggregateExpression or PythonUDF, and create a map from the alias to the expression
-    val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
+    val aliasMap = plan.aggregateExpressions.collect {
       case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
         PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
         (a.toAttribute, a)
@@ -80,14 +80,6 @@ object AggregateExpression {
       filter,
       NamedExpression.newExprId)
   }
-
-  def containsAggregate(expr: Expression): Boolean = {
-    expr.find(isAggregate).isDefined
-  }
-
-  def isAggregate(expr: Expression): Boolean = {
-    expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
-  }
 }
 
 /**
@@ -277,22 +277,3 @@ object GroupingAnalytics {
     }
   }
 }
-
-/**
- * A reference to a grouping expression in an [[Aggregate]] node.
- *
- * @param ordinal The ordinal of the grouping expression in [[Aggregate]] that this expression
- *                refers to.
- * @param dataType The [[DataType]] of the referenced grouping expression.
- * @param nullable True if null is a valid value for the referenced grouping expression.
- */
-case class GroupingExprRef(
-    ordinal: Int,
-    dataType: DataType,
-    nullable: Boolean)
-  extends LeafExpression with Unevaluable {
-
-  override def stringArgs: Iterator[Any] = {
-    Iterator(ordinal)
-  }
-}
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
@@ -26,6 +26,15 @@ import org.apache.spark.sql.catalyst.rules.Rule
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // One place where this optimization is invalid is an aggregation where the select
+    // list expression is a function of a grouping expression:
+    //
+    //   SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
+    //
+    // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
+    // optimization for Aggregates (although this misses some cases where the optimization
+    // can be made).
+    case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
      case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
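For intuition on the restored guard: struct(a,b) is the grouping key, so the
select list may only refer to the whole grouping expression (or to aggregate
functions), never to a bare column inside it. Both query shapes below are the
ones named in the restored comment:

    -- valid: extracts a field from the grouping expression itself
    SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b);
    -- invalid result of the skipped "simplification": a alone is not grouped
    SELECT a FROM tbl GROUP BY struct(a,b);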
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-
-/**
- * This rule ensures that [[Aggregate]] nodes contain all required [[GroupingExprRef]]
- * references for the optimization phase.
- */
-object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transform {
-      case a: Aggregate =>
-        Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
-    }
-  }
-}
@@ -119,8 +119,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         OptimizeUpdateFields,
         SimplifyExtractValueOps,
         OptimizeCsvJsonExprs,
-        CombineConcats,
-        UpdateGroupingExprRefNullability) ++
+        CombineConcats) ++
       extendedOperatorOptimizationRules
 
     val operatorOptimizationBatch: Seq[Batch] = {
@@ -149,7 +148,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateView,
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
-      EnforceGroupingReferencesInAggregates,
       ComputeCurrentTime,
       GetCurrentDatabaseAndCatalog(catalogManager)) ::
     //////////////////////////////////////////////////////////////////////////////////////////
@@ -269,9 +267,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RewriteCorrelatedScalarSubquery.ruleName ::
       RewritePredicateSubquery.ruleName ::
       NormalizeFloatingNumbers.ruleName ::
-      ReplaceUpdateFieldsExpression.ruleName ::
-      EnforceGroupingReferencesInAggregates.ruleName ::
-      UpdateGroupingExprRefNullability.ruleName :: Nil
+      ReplaceUpdateFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
@@ -512,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
     case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
       val aliasMap = getAliasMap(lower)
 
-      val newAggregate = Aggregate.withGroupingRefs(
+      val newAggregate = upper.copy(
         child = lower.child,
         groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
         aggregateExpressions = upper.aggregateExpressions.map(
@@ -528,19 +524,23 @@
   }
 
   private def lowerIsRedundant(upper: Aggregate, lower: Aggregate): Boolean = {
-    val upperHasNoAggregateExpressions =
-      !upper.aggregateExpressions.exists(AggregateExpression.containsAggregate)
+    val upperHasNoAggregateExpressions = !upper.aggregateExpressions.exists(isAggregate)
 
     lazy val upperRefsOnlyDeterministicNonAgg = upper.references.subsetOf(AttributeSet(
       lower
         .aggregateExpressions
         .filter(_.deterministic)
-        .filterNot(AggregateExpression.containsAggregate)
+        .filter(!isAggregate(_))
         .map(_.toAttribute)
     ))
 
     upperHasNoAggregateExpressions && upperRefsOnlyDeterministicNonAgg
   }
+
+  private def isAggregate(expr: Expression): Boolean = {
+    expr.find(e => e.isInstanceOf[AggregateExpression] ||
+      PythonUDF.isGroupedAggPandasUDF(e)).isDefined
+  }
 }
 
 /**
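RemoveRedundantAggregates, adjusted above, collapses an aggregate stacked on an
equivalent aggregate when the upper one contains no aggregate functions.
Roughly, in SQL terms (a hypothetical table t(a, b); this mirrors the
groupBy('a + 'b)(('a + 'b) as 'c) case in RemoveRedundantAggregatesSuite below):

    SELECT c FROM (SELECT a + b AS c FROM t GROUP BY a + b) GROUP BY c
    -- the lower aggregate is redundant; the pair collapses to:
    SELECT a + b AS c FROM t GROUP BY a + b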
@@ -1978,18 +1978,7 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
     case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
       if (newGrouping.nonEmpty) {
-        val droppedGroupsBefore =
-          grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray
-
-        val newAggregateExpressions =
-          a.aggregateExpressions.map(_.transform {
-            case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
-              g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
-          }.asInstanceOf[NamedExpression])
-
-        a.copy(
-          groupingExpressions = newGrouping,
-          aggregateExpressions = newAggregateExpressions)
+        a.copy(groupingExpressions = newGrouping)
       } else {
         // All grouping expressions are literals. We should not drop them all, because this can
         // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
@@ -2010,25 +1999,7 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
       if (newGrouping.size == grouping.size) {
         a
       } else {
-        var i = 0
-        val droppedGroupsBefore = grouping.scanLeft(0)((n, e) =>
-          n + (if (i >= newGrouping.size || e.eq(newGrouping(i))) {
-            i += 1
-            0
-          } else {
-            1
-          })
-        ).toArray
-
-        val newAggregateExpressions =
-          a.aggregateExpressions.map(_.transform {
-            case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
-              g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
-          }.asInstanceOf[NamedExpression])
-
-        a.copy(
-          groupingExpressions = newGrouping,
-          aggregateExpressions = newAggregateExpressions)
+        a.copy(groupingExpressions = newGrouping)
       }
   }
 }
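The two hunks above drop the ordinal bookkeeping that kept GroupingExprRef
ordinals valid after grouping expressions were removed. The scanLeft trick can
be checked in isolation; a minimal sketch in plain Scala (no Spark
dependencies, illustrative values only):

    object OrdinalShiftDemo extends App {
      // Suppose the grouping list is (lit(1), a, lit(2), b) and the foldable
      // literals get dropped, as in RemoveLiteralFromGroupExpressions.
      val foldable = Seq(true, false, true, false)

      // droppedGroupsBefore(i) = number of dropped expressions before index i.
      val droppedGroupsBefore =
        foldable.scanLeft(0)((n, f) => n + (if (f) 1 else 0)).toArray

      // A reference to old ordinal 3 (b) must shift to ordinal 1 after the drop.
      val oldOrdinal = 3
      println(oldOrdinal - droppedGroupsBefore(oldOrdinal)) // prints 1
    }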
@@ -633,10 +633,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelper
    * subqueries.
    */
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
-    case a @ Aggregate(grouping, _, child) =>
+    case a @ Aggregate(grouping, expressions, child) =>
       val subqueries = ArrayBuffer.empty[ScalarSubquery]
-      val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
-        .map(extractCorrelatedScalarSubqueries(_, subqueries))
+      val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
         // We currently only allow correlated subqueries in an aggregate if they are part of the
         // grouping expressions. As a result we need to replace all the scalar subqueries in the
@@ -287,7 +287,7 @@ object PhysicalAggregation {
     (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
 
   def unapply(a: Any): Option[ReturnType] = a match {
-    case a @ logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
       // A single aggregate expression might appear multiple times in resultExpressions.
       // In order to avoid evaluating an individual aggregate function multiple times, we'll
      // build a set of semantically distinct aggregate expressions and re-write expressions so
@@ -297,9 +297,11 @@
       val aggregateExpressions = resultExpressions.flatMap { expr =>
         expr.collect {
           // addExpr() always returns false for non-deterministic expressions and does not add them.
-          case a
-            if AggregateExpression.isAggregate(a) && !equivalentAggregateExpressions.addExpr(a) =>
-            a
+          case agg: AggregateExpression
+            if !equivalentAggregateExpressions.addExpr(agg) => agg
+          case udf: PythonUDF
+            if PythonUDF.isGroupedAggPandasUDF(udf) &&
+              !equivalentAggregateExpressions.addExpr(udf) => udf
         }
       }
 
@@ -320,7 +322,7 @@
       // which takes the grouping columns and final aggregate result buffer as input.
       // Thus, we must re-write the result expressions so that their attributes match up with
       // the attributes of the final result projection's input row:
-      val rewrittenResultExpressions = a.aggregateExpressionsWithoutGroupingRefs.map { expr =>
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
         expr.transformDown {
           case ae: AggregateExpression =>
             // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import scala.collection.mutable
-
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, MultiInstanceRelation, TypeCoercion, TypeCoercionBase}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
@@ -783,23 +781,14 @@ case class Range(
 /**
  * This is a Group by operator with the aggregate functions and projections.
  *
- * @param groupingExpressions Expressions for grouping keys.
- * @param aggregateExpressions Expressions for a project list, which can contain
- *                             [[AggregateExpression]]s and [[GroupingExprRef]]s.
- * @param child The child of the aggregate node.
+ * @param groupingExpressions expressions for grouping keys
+ * @param aggregateExpressions expressions for a project list, which could contain
+ *                             [[AggregateExpression]]s.
  *
- * Expressions without aggregate functions in [[aggregateExpressions]] can contain
- * [[GroupingExprRef]]s to refer to complex grouping expressions in [[groupingExpressions]]. These
- * references ensure that optimization rules don't change the aggregate expressions to invalid ones
- * that no longer refer to any grouping expressions and also simplify the expression transformations
- * on the node (need to transform the expression only once).
- *
- * For example, in the following query Spark shouldn't optimize the aggregate expression
- * `Not(IsNull(c))` to `IsNotNull(c)` as the grouping expression is `IsNull(c)`:
- *   SELECT not(c IS NULL)
- *   FROM t
- *   GROUP BY c IS NULL
- * Instead, the aggregate expression should contain `Not(GroupingExprRef(0))`.
+ * Note: Currently, aggregateExpressions is the project list of this Group by operator. Before
+ * separating projection from grouping and aggregate, we should avoid expression-level optimization
+ * on aggregateExpressions, which could reference an expression in groupingExpressions.
+ * For example, see the rule [[org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps]]
  */
 case class Aggregate(
     groupingExpressions: Seq[Expression],
@@ -826,21 +815,8 @@ case class Aggregate(
     }
   }
 
-  private def expandGroupingReferences(e: Expression): Expression = {
-    e match {
-      case _ if AggregateExpression.isAggregate(e) => e
-      case g: GroupingExprRef => groupingExpressions(g.ordinal)
-      case _ => e.mapChildren(expandGroupingReferences)
-    }
-  }
-
-  lazy val aggregateExpressionsWithoutGroupingRefs = {
-    aggregateExpressions.map(expandGroupingReferences(_).asInstanceOf[NamedExpression])
-  }
-
   override lazy val validConstraints: ExpressionSet = {
-    val nonAgg = aggregateExpressionsWithoutGroupingRefs.
-      filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
+    val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty)
     getAllValidConstraints(nonAgg)
   }
 
@@ -848,51 +824,6 @@ case class Aggregate(
     copy(child = newChild)
 }
 
-object Aggregate {
-  private def collectComplexGroupingExpressions(groupingExpressions: Seq[Expression]) = {
-    val complexGroupingExpressions = mutable.Map.empty[Expression, (Expression, Int)]
-    var i = 0
-    groupingExpressions.foreach { ge =>
-      if (!ge.foldable && ge.children.nonEmpty &&
-          !complexGroupingExpressions.contains(ge.canonicalized)) {
-        complexGroupingExpressions += ge.canonicalized -> (ge, i)
-      }
-      i += 1
-    }
-    complexGroupingExpressions
-  }
-
-  private def insertGroupingReferences(
-      aggregateExpressions: Seq[NamedExpression],
-      groupingExpressions: collection.Map[Expression, (Expression, Int)]): Seq[NamedExpression] = {
-    def insertGroupingExprRefs(e: Expression): Expression = {
-      e match {
-        case _ if AggregateExpression.isAggregate(e) => e
-        case _ if groupingExpressions.contains(e.canonicalized) =>
-          val (groupingExpression, ordinal) = groupingExpressions(e.canonicalized)
-          GroupingExprRef(ordinal, groupingExpression.dataType, groupingExpression.nullable)
-        case _ => e.mapChildren(insertGroupingExprRefs)
-      }
-    }
-
-    aggregateExpressions.map(insertGroupingExprRefs(_).asInstanceOf[NamedExpression])
-  }
-
-  def withGroupingRefs(
-      groupingExpressions: Seq[Expression],
-      aggregateExpressions: Seq[NamedExpression],
-      child: LogicalPlan): Aggregate = {
-    val complexGroupingExpressions = collectComplexGroupingExpressions(groupingExpressions)
-    val aggrExprWithGroupingReferences = if (complexGroupingExpressions.nonEmpty) {
-      insertGroupingReferences(aggregateExpressions, complexGroupingExpressions)
-    } else {
-      aggregateExpressions
-    }
-
-    new Aggregate(groupingExpressions, aggrExprWithGroupingReferences, child)
-  }
-}
-
 case class Window(
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
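The removed Aggregate.withGroupingRefs helper above worked in two passes: index
the complex grouping expressions by their canonicalized form, then rewrite
matching subtrees of the aggregate list into ordinal-based references. A
standalone sketch of that flow in plain Scala (no Spark dependencies; strings
stand in for canonicalized expressions, and all names are illustrative):

    object WithGroupingRefsSketch extends App {
      // Grouping list with one complex expression, "isnull(c)", at ordinal 0.
      val grouping = Seq("isnull(c)")
      val byExpr: Map[String, Int] = grouping.zipWithIndex.toMap

      // Aggregate list: "not(isnull(c))" contains the grouping expression.
      val aggregateList = Seq("not(isnull(c))")

      // Replace each embedded grouping expression with "ref(<ordinal>)".
      val withRefs = aggregateList.map { e =>
        byExpr.foldLeft(e) { case (acc, (ge, ord)) => acc.replace(ge, s"ref($ord)") }
      }
      println(withRefs) // List(not(ref(0)))
    }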
@@ -96,7 +96,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy('a + 'b)(('a + 'b) as 'c)
         .analyze
       val optimized = Optimize.execute(query)
-      comparePlans(optimized, EnforceGroupingReferencesInAggregates(expected))
+      comparePlans(optimized, expected)
     }
   }
 
@@ -36,8 +36,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
     val batches =
-      Batch("Finish Analysis", Once,
-        EnforceGroupingReferencesInAggregates) ::
       Batch("collapse projections", FixedPoint(10),
         CollapseProject) ::
       Batch("Constant Folding", FixedPoint(10),
@@ -59,7 +57,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
   private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = {
     val optimized = Optimizer.execute(originalQuery.analyze)
     assert(optimized.resolved, "optimized plans must still be resolvable")
-    comparePlans(optimized, EnforceGroupingReferencesInAggregates(correctAnswer.analyze))
+    comparePlans(optimized, correctAnswer.analyze)
   }
 
   test("explicit get from namedStruct") {
@@ -407,6 +405,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val arrayAggRel = relation.groupBy(
       CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0))
     checkRule(arrayAggRel, arrayAggRel)
+
+    // This could be done if we had a more complex rule that checks that
+    // the CreateMap does not come from key.
+    val originalQuery = relation
+      .groupBy('id)(
+        GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a"
+      )
+    checkRule(originalQuery, originalQuery)
   }
 
   test("SPARK-23500: namedStruct and getField in the same Project #1") {
@@ -40,7 +40,6 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
   private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
     e.isInstanceOf[AggregateExpression] ||
       PythonUDF.isGroupedAggPandasUDF(e) ||
-      e.isInstanceOf[GroupingExprRef] ||
       agg.groupingExpressions.exists(_.semanticEquals(e))
   }
 
@@ -120,8 +119,23 @@ object ExtractGroupingPythonUDFFromAggregate extends Rule[LogicalPlan] {
         groupingExpr += expr
       }
     }
+    val aggExpr = agg.aggregateExpressions.map { expr =>
+      expr.transformUp {
+        // PythonUDF over aggregate was pulled out by ExtractPythonUDFFromAggregate.
+        // PythonUDF here should be either
+        // 1. Argument of an aggregate function.
+        //    CheckAnalysis guarantees the arguments are deterministic.
+        // 2. PythonUDF in grouping key. Grouping key must be deterministic.
+        // 3. PythonUDF not in grouping key. It is either no arguments or with grouping key
+        //    in its arguments. Such PythonUDF was pulled out by ExtractPythonUDFFromAggregate, too.
+        case p: PythonUDF if p.udfDeterministic =>
+          val canonicalized = p.canonicalized.asInstanceOf[PythonUDF]
+          attributeMap.getOrElse(canonicalized, p)
+      }.asInstanceOf[NamedExpression]
+    }
     agg.copy(
       groupingExpressions = groupingExpr.toSeq,
       aggregateExpressions = aggExpr,
       child = Project((projList ++ agg.child.output).toSeq, agg.child))
   }
 
@@ -179,12 +179,3 @@ SELECT count(*) FROM test_agg WHERE k = 1 or k = 2 or count(*) + 1L > 1L or max(
 
 -- Aggregate with multiple distinct decimal columns
 SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1 AS DECIMAL(9, 0))) t(decimal_col);
-
--- SPARK-34581: Don't optimize out grouping expressions from aggregate expressions without aggregate function
-SELECT not(a IS NULL), count(*) AS c
-FROM testData
-GROUP BY a IS NULL;
-
-SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
-FROM testData
-GROUP BY a IS NULL;
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 64
+-- Number of queries: 62
 
 
 -- !query
@@ -642,25 +642,3 @@ SELECT AVG(DISTINCT decimal_col), SUM(DISTINCT decimal_col) FROM VALUES (CAST(1
 struct<avg(DISTINCT decimal_col):decimal(13,4),sum(DISTINCT decimal_col):decimal(19,0)>
 -- !query output
 1.0000	1
-
-
--- !query
-SELECT not(a IS NULL), count(*) AS c
-FROM testData
-GROUP BY a IS NULL
--- !query schema
-struct<(NOT (a IS NULL)):boolean,c:bigint>
--- !query output
-false	2
-true	7
-
-
--- !query
-SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c
-FROM testData
-GROUP BY a IS NULL
--- !query schema
-struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint>
--- !query output
-0.7604953758285915	7
-1.0	2