diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java index 039252348d..cf7dbb2978 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java @@ -20,7 +20,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; /** * Aggregation in SQL statement. @@ -30,14 +30,14 @@ import org.apache.spark.sql.connector.expressions.FieldReference; @Evolving public final class Aggregation implements Serializable { private final AggregateFunc[] aggregateExpressions; - private final FieldReference[] groupByColumns; + private final NamedReference[] groupByColumns; - public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) { + public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) { this.aggregateExpressions = aggregateExpressions; this.groupByColumns = groupByColumns; } public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; } - public FieldReference[] groupByColumns() { return groupByColumns; } + public NamedReference[] groupByColumns() { return groupByColumns; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 14493a4339..1273886e29 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; /** * An aggregate function that returns the number of the specific row in a group. @@ -27,15 +27,15 @@ import org.apache.spark.sql.connector.expressions.FieldReference; */ @Evolving public final class Count implements AggregateFunc { - private final FieldReference column; + private final NamedReference column; private final boolean isDistinct; - public Count(FieldReference column, boolean isDistinct) { + public Count(NamedReference column, boolean isDistinct) { this.column = column; this.isDistinct = isDistinct; } - public FieldReference column() { return column; } + public NamedReference column() { return column; } public boolean isDistinct() { return isDistinct; } @Override diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index 985fd80552..ed07cc9e32 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; /** * An aggregate function that returns the maximum value in a group. @@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.expressions.FieldReference; */ @Evolving public final class Max implements AggregateFunc { - private final FieldReference column; + private final NamedReference column; - public Max(FieldReference column) { this.column = column; } + public Max(NamedReference column) { this.column = column; } - public FieldReference column() { return column; } + public NamedReference column() { return column; } @Override public String toString() { return "MAX(" + column.describe() + ")"; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 7b7b557844..2e76103774 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; /** * An aggregate function that returns the minimum value in a group. @@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.expressions.FieldReference; */ @Evolving public final class Min implements AggregateFunc { - private final FieldReference column; + private final NamedReference column; - public Min(FieldReference column) { this.column = column; } + public Min(NamedReference column) { this.column = column; } - public FieldReference column() { return column; } + public NamedReference column() { return column; } @Override public String toString() { return "MIN(" + column.describe() + ")"; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 66ce436e70..057ebd89f7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.NamedReference; /** * An aggregate function that returns the summation of all the values in a group. @@ -27,15 +27,15 @@ import org.apache.spark.sql.connector.expressions.FieldReference; */ @Evolving public final class Sum implements AggregateFunc { - private final FieldReference column; + private final NamedReference column; private final boolean isDistinct; - public Sum(FieldReference column, boolean isDistinct) { + public Sum(NamedReference column, boolean isDistinct) { this.column = column; this.isDistinct = isDistinct; } - public FieldReference column() { return column; } + public NamedReference column() { return column; } public boolean isDistinct() { return isDistinct; } @Override diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7a5c343133..a53665fe2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -702,20 +702,19 @@ object DataSourceStrategy if (aggregates.filter.isEmpty) { aggregates.aggregateFunction match { case aggregate.Min(PushableColumnWithoutNestedColumn(name)) => - Some(new Min(FieldReference(name).asInstanceOf[FieldReference])) + Some(new Min(FieldReference(name))) case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => - Some(new Max(FieldReference(name).asInstanceOf[FieldReference])) + Some(new Max(FieldReference(name))) case count: aggregate.Count if count.children.length == 1 => count.children.head match { // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table case Literal(_, _) => Some(new CountStar()) case PushableColumnWithoutNestedColumn(name) => - Some(new Count(FieldReference(name).asInstanceOf[FieldReference], - aggregates.isDistinct)) + Some(new Count(FieldReference(name), aggregates.isDistinct)) case _ => None } case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => - Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct)) + Some(new Sum(FieldReference(name), aggregates.isDistinct)) case _ => None } } else {