[SPARK-36677][SQL] NestedColumnAliasing should not push down aggregate functions into projections

### What changes were proposed in this pull request?

This PR filters out `ExtractValue`s that contain any aggregate function in the `NestedColumnAliasing` rule, to prevent aggregates from being pushed down into projections.

### Why are the changes needed?

To handle a corner case missed by `NestedColumnAliasing` that can cause users to hit a runtime exception.

Consider the following schema:
```
root
 |-- a: struct (nullable = true)
 |    |-- c: struct (nullable = true)
 |    |    |-- e: string (nullable = true)
 |    |-- d: integer (nullable = true)
 |-- b: string (nullable = true)
```
and the query:
`SELECT MAX(a).c.e FROM (SELECT a, b FROM test_aggregates) GROUP BY b`

Before this PR, executing the query fails with:
```
java.lang.UnsupportedOperationException: Cannot generate code for expression: max(input[0, struct<c:struct<e:string>,d:int>, true])
  at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotGenerateCodeForExpressionError(QueryExecutionErrors.scala:83)
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:312)
  at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:311)
  at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:99)
...
```
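
For reference, a minimal repro sketch (assuming an active `SparkSession` available as `spark`; the table name and schema mirror the example above, and the table is stored as Parquet, matching the plan below):
```
// Hypothetical repro, not part of this PR: create a Parquet table matching the
// schema above, then run the failing query.
spark.sql(
  """CREATE TABLE test_aggregates (
    |  a STRUCT<c: STRUCT<e: STRING>, d: INT>,
    |  b STRING
    |) USING parquet""".stripMargin)

// Before this PR, this fails at codegen with
// "Cannot generate code for expression: max(...)".
spark.sql("SELECT MAX(a).c.e FROM (SELECT a, b FROM test_aggregates) GROUP BY b").show()
```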
The optimized plan before this PR is shown below. Note that the aliased `max(a#0).c.e` aggregate expression has been pushed down into the `Project`, where it cannot be evaluated:

```
'Aggregate [b#1], [_extract_e#5 AS max(a).c.e#3]
+- 'Project [max(a#0).c.e AS _extract_e#5, b#1]
   +- Relation default.test_aggregates[a#0,b#1] parquet
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A new unit test in `NestedColumnAliasingSuite` that reproduces the query described above. The resulting optimized plan is checked for equivalence with a plan of the form:
```
Aggregate [b#452], [max(a#451).c.e AS max('a)[c][e]#456]
+- LocalRelation <empty>, [a#451, b#452]
```

Closes #33921 from vicennial/spark-36677.

Authored-by: Venkata Sai Akhil Gudesa <venkata.gudesa@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>

In `NestedColumnAliasing.scala`:

```
@@ -21,6 +21,7 @@ import scala.collection
 import scala.collection.mutable

 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -258,6 +259,13 @@ object NestedColumnAliasing {
       .filter(!_.references.subsetOf(exclusiveAttrSet))
       .groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
       .flatMap { case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
+
+        // Check if `ExtractValue` expressions contain any aggregate functions in their tree. Those
+        // that do should not have an alias generated as it can lead to pushing the aggregate down
+        // into a projection.
+        def containsAggregateFunction(ev: ExtractValue): Boolean =
+          ev.find(_.isInstanceOf[AggregateFunction]).isDefined
+
         // Remove redundant [[ExtractValue]]s if they share the same parent nest field.
         // For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
         // Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
@@ -268,7 +276,10 @@ object NestedColumnAliasing {
             val child = e.children.head
             nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
           case _ => true
-        }.distinct
+        }
+          .distinct
+          // Discard [[ExtractValue]]s that contain aggregate functions.
+          .filterNot(containsAggregateFunction)

         // If all nested fields of `attr` are used, we don't need to introduce new aliases.
         // By default, the [[ColumnPruning]] rule uses `attr` already.
@@ -276,7 +287,7 @@ object NestedColumnAliasing {
         // nested field once.
         val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
           .map { nestedField => totalFieldNum(nestedField.dataType) }.sum
-        if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
+        if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
          Some((attr, dedupNestedFields.toSeq))
        } else {
          None
```
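
For illustration only (not part of the patch), a rough sketch of the kind of expression the new `containsAggregateFunction` check matches, built with Catalyst's expression DSL; the expression mimics `max(a).c.e` from the example above:
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Max}

// Attribute `a` with the nested schema from the example.
val a = 'a.struct('c.struct('e.string), 'd.int)

// Nested field extraction on top of an aggregate, i.e. `max(a).c.e`.
val maxACE = GetStructField(
  GetStructField(Max(a).toAggregateExpression(), 0, Some("c")), 0, Some("e"))

// The predicate added by this PR: the tree contains an aggregate function,
// so no `_extract_*` alias is generated for it.
val hasAgg = maxACE.find(_.isInstanceOf[AggregateFunction]).isDefined  // true
```
Because such candidates are now rejected, the aggregate stays in the `Aggregate` node instead of being aliased into the child `Project`.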

In `NestedColumnAliasingSuite.scala`:

```
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.catalyst.SchemaPruningTest
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
@@ -763,6 +764,32 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
       $"_extract_search_params.col2".as("col2")).analyze
     comparePlans(optimized, query)
   }
+
+  test("SPARK-36677: NestedColumnAliasing should not push down aggregate functions into " +
+    "projections") {
+    val nestedRelation = LocalRelation(
+      'a.struct(
+        'c.struct(
+          'e.string),
+        'd.string),
+      'b.string)
+
+    val plan = nestedRelation
+      .select($"a", $"b")
+      .groupBy($"b")(max($"a").getField("c").getField("e"))
+      .analyze
+
+    val optimized = Optimize.execute(plan)
+
+    // The plan should not contain aggregation functions inside the projection
+    SimpleAnalyzer.checkAnalysis(optimized)
+
+    val expected = nestedRelation
+      .groupBy($"b")(max($"a").getField("c").getField("e"))
+      .analyze
+
+    comparePlans(optimized, expected)
+  }
 }

 object NestedColumnAliasingSuite {
```