diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 9facae3b57..e2553f7832 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -21,6 +21,7 @@ import scala.collection import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -258,6 +259,13 @@ object NestedColumnAliasing { .filter(!_.references.subsetOf(exclusiveAttrSet)) .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute]) .flatMap { case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) => + + // Check if `ExtractValue` expressions contain any aggregate functions in their tree. Those + // that do should not have an alias generated as it can lead to pushing the aggregate down + // into a projection. + def containsAggregateFunction(ev: ExtractValue): Boolean = + ev.find(_.isInstanceOf[AggregateFunction]).isDefined + // Remove redundant [[ExtractValue]]s if they share the same parent nest field. // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`. // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`. @@ -268,7 +276,10 @@ object NestedColumnAliasing { val child = e.children.head nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty) case _ => true - }.distinct + } + .distinct + // Discard [[ExtractValue]]s that contain aggregate functions. + .filterNot(containsAggregateFunction) // If all nested fields of `attr` are used, we don't need to introduce new aliases. // By default, the [[ColumnPruning]] rule uses `attr` already. @@ -276,7 +287,7 @@ object NestedColumnAliasing { // nested field once. val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct .map { nestedField => totalFieldNum(nestedField.dataType) }.sum - if (numUsedNestedFields < totalFieldNum(attr.dataType)) { + if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) { Some((attr, dedupNestedFields.toSeq)) } else { None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index e49e028423..40ab72c89f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -763,6 +764,32 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { $"_extract_search_params.col2".as("col2")).analyze comparePlans(optimized, query) } + + test("SPARK-36677: NestedColumnAliasing should not push down aggregate functions into " + + "projections") { + val nestedRelation = LocalRelation( + 'a.struct( + 'c.struct( + 'e.string), + 'd.string), + 'b.string) + + val plan = nestedRelation + .select($"a", $"b") + .groupBy($"b")(max($"a").getField("c").getField("e")) + .analyze + + val optimized = Optimize.execute(plan) + + // The plan should not contain aggregation functions inside the projection + SimpleAnalyzer.checkAnalysis(optimized) + + val expected = nestedRelation + .groupBy($"b")(max($"a").getField("c").getField("e")) + .analyze + + comparePlans(optimized, expected) + } } object NestedColumnAliasingSuite {