diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index ab5a0feb62..8b253da3d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -93,8 +93,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                     agg
                 }
              }
-              val pushedAggregates = PushDownUtils
-                .pushAggregates(sHolder.builder, aggregates, groupingExpressions)
+              val normalizedAggregates = DataSourceStrategy.normalizeExprs(
+                aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+              val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
+                groupingExpressions, sHolder.relation.output)
+              val pushedAggregates = PushDownUtils.pushAggregates(
+                sHolder.builder, normalizedAggregates, normalizedGroupingExpressions)
               if (pushedAggregates.isEmpty) {
                 aggNode // return original plan node
               } else {
@@ -115,7 +119,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                 // scalastyle:on
                 val newOutput = scan.readSchema().toAttributes
                 assert(newOutput.length == groupingExpressions.length + aggregates.length)
-                val groupAttrs = groupingExpressions.zip(newOutput).map {
+                val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
                   case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
                   case (_, b) => b
                 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 37bc35210e..526dad91e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -239,8 +239,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
   test("scan with aggregate push-down: MAX MIN with filter and group by") {
-    val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
-      " group by DEPT")
+    val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
     val filters = df.queryExecution.optimizedPlan.collect {
       case f: Filter => f
     }
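
Note on the fix: under case-insensitive analysis, attribute references keep the name exactly as the user typed it ("SaLaRY", "DePt"), so the aggregates and grouping expressions previously handed to PushDownUtils.pushAggregates could carry names that do not match the relation's actual columns. DataSourceStrategy.normalizeExprs repairs this by rewriting each AttributeReference, matched by exprId, to the exact-case name of the corresponding attribute in the relation output. A minimal sketch of that rewrite, written as a standalone helper mirroring the utility's behavior rather than the patched code itself:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}

    // Rewrite every AttributeReference in `exprs` to the canonical name of the
    // relation output attribute that shares its exprId; references with no
    // matching output attribute keep their original name.
    def normalizeExprs(
        exprs: Seq[Expression],
        output: Seq[AttributeReference]): Seq[Expression] = {
      exprs.map { e =>
        e.transform {
          case a: AttributeReference =>
            a.withName(output.find(_.exprId == a.exprId).map(_.name).getOrElse(a.name))
        }
      }
    }

The JDBCV2Suite change exercises exactly this path: MAX(SaLaRY) and group by DePt are only pushed down to H2 once the names are normalized back to the relation's SALARY and DEPT columns.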