[SPARK-36677][SQL] NestedColumnAliasing should not push down aggregate functions into projections
### What changes were proposed in this pull request?
This PR filters out `ExtractValue`s that contain any aggregate function in the `NestedColumnAliasing` rule, to prevent aggregate functions from being pushed down into projections.
### Why are the changes needed?
To handle a missed corner case in `NestedColumnAliasing` that can cause users to hit a runtime exception.
Consider the following schema:
```
root
 |-- a: struct (nullable = true)
 |    |-- c: struct (nullable = true)
 |    |    |-- e: string (nullable = true)
 |    |-- d: integer (nullable = true)
 |-- b: string (nullable = true)
```
and the query:
`SELECT MAX(a).c.e FROM (SELECT a, b FROM test_aggregates) GROUP BY b`
Executing the query before this PR will result in the error:
```
java.lang.UnsupportedOperationException: Cannot generate code for expression: max(input[0, struct<c:struct<e:string>,d:int>, true])
at org.apache.spark.sql.errors.QueryExecutionErrors$.cannotGenerateCodeForExpressionError(QueryExecutionErrors.scala:83)
at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode(Expression.scala:312)
at org.apache.spark.sql.catalyst.expressions.Unevaluable.doGenCode$(Expression.scala:311)
at org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression.doGenCode(interfaces.scala:99)
...
```
The optimised plan before this PR is:
```
'Aggregate [b#1], [_extract_e#5 AS max(a).c.e#3]
+- 'Project [max(a#0).c.e AS _extract_e#5, b#1]
   +- Relation default.test_aggregates[a#0,b#1] parquet
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
A new unit test in `NestedColumnAliasingSuite`. The test consists of the repro mentioned earlier.
The produced optimized plan is checked for equivalency with a plan of the form:
```
Aggregate [b#452], [max(a#451).c.e AS max('a)[c][e]#456]
+- LocalRelation <empty>, [a#451, b#452]
```
Closes #33921 from vicennial/spark-36677.
Authored-by: Venkata Sai Akhil Gudesa <venkata.gudesa@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
(cherry picked from commit 2ed6e7bc5d)
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
This commit is contained in:
parent
e39948fada
commit
4a92b0e278
|
@ -21,6 +21,7 @@ import scala.collection
|
|||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -258,6 +259,13 @@ object NestedColumnAliasing {
|
|||
.filter(!_.references.subsetOf(exclusiveAttrSet))
|
||||
.groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
|
||||
.flatMap { case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>
|
||||
|
||||
// Check if `ExtractValue` expressions contain any aggregate functions in their tree. Those
|
||||
// that do should not have an alias generated as it can lead to pushing the aggregate down
|
||||
// into a projection.
|
||||
def containsAggregateFunction(ev: ExtractValue): Boolean =
|
||||
ev.find(_.isInstanceOf[AggregateFunction]).isDefined
|
||||
|
||||
// Remove redundant [[ExtractValue]]s if they share the same parent nest field.
|
||||
// For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
|
||||
// Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
|
||||
|
@ -268,7 +276,10 @@ object NestedColumnAliasing {
|
|||
val child = e.children.head
|
||||
nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
|
||||
case _ => true
|
||||
}.distinct
|
||||
}
|
||||
.distinct
|
||||
// Discard [[ExtractValue]]s that contain aggregate functions.
|
||||
.filterNot(containsAggregateFunction)
|
||||
|
||||
// If all nested fields of `attr` are used, we don't need to introduce new aliases.
|
||||
// By default, the [[ColumnPruning]] rule uses `attr` already.
|
||||
|
@ -276,7 +287,7 @@ object NestedColumnAliasing {
|
|||
// nested field once.
|
||||
val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
|
||||
.map { nestedField => totalFieldNum(nestedField.dataType) }.sum
|
||||
if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
|
||||
if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
|
||||
Some((attr, dedupNestedFields.toSeq))
|
||||
} else {
|
||||
None
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
|
|||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.sql.catalyst.SchemaPruningTest
|
||||
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -763,6 +764,32 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
|
|||
$"_extract_search_params.col2".as("col2")).analyze
|
||||
comparePlans(optimized, query)
|
||||
}
|
||||
|
||||
test("SPARK-36677: NestedColumnAliasing should not push down aggregate functions into " +
|
||||
"projections") {
|
||||
val nestedRelation = LocalRelation(
|
||||
'a.struct(
|
||||
'c.struct(
|
||||
'e.string),
|
||||
'd.string),
|
||||
'b.string)
|
||||
|
||||
val plan = nestedRelation
|
||||
.select($"a", $"b")
|
||||
.groupBy($"b")(max($"a").getField("c").getField("e"))
|
||||
.analyze
|
||||
|
||||
val optimized = Optimize.execute(plan)
|
||||
|
||||
// The plan should not contain aggregation functions inside the projection
|
||||
SimpleAnalyzer.checkAnalysis(optimized)
|
||||
|
||||
val expected = nestedRelation
|
||||
.groupBy($"b")(max($"a").getField("c").getField("e"))
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, expected)
|
||||
}
|
||||
}
|
||||
|
||||
object NestedColumnAliasingSuite {
|
||||
|
|
Loading…
Reference in a new issue