[SPARK-34952][SQL][FOLLOWUP] Change column type to be NamedReference

### What changes were proposed in this pull request?
Currently, we have `FieldReference` for aggregate column type, should be `NamedReference` instead

### Why are the changes needed?
`FieldReference` is a private class, should use `NamedReference` instead

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
existing tests

Closes #33927 from huaxingao/agg_followup.

Authored-by: Huaxin Gao <huaxin_gao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 23794fb303)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Huaxin Gao 2021-09-08 14:05:44 +08:00 committed by Wenchen Fan
parent c4332c7bf0
commit 7e8860751c
6 changed files with 24 additions and 25 deletions

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.connector.expressions.aggregate;
import java.io.Serializable; import java.io.Serializable;
import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.FieldReference; import org.apache.spark.sql.connector.expressions.NamedReference;
/** /**
* Aggregation in SQL statement. * Aggregation in SQL statement.
@ -30,14 +30,14 @@ import org.apache.spark.sql.connector.expressions.FieldReference;
@Evolving @Evolving
public final class Aggregation implements Serializable { public final class Aggregation implements Serializable {
private final AggregateFunc[] aggregateExpressions; private final AggregateFunc[] aggregateExpressions;
private final FieldReference[] groupByColumns; private final NamedReference[] groupByColumns;
public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) { public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) {
this.aggregateExpressions = aggregateExpressions; this.aggregateExpressions = aggregateExpressions;
this.groupByColumns = groupByColumns; this.groupByColumns = groupByColumns;
} }
public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; } public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
public FieldReference[] groupByColumns() { return groupByColumns; } public NamedReference[] groupByColumns() { return groupByColumns; }
} }

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate; package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.FieldReference; import org.apache.spark.sql.connector.expressions.NamedReference;
/** /**
* An aggregate function that returns the number of the specific row in a group. * An aggregate function that returns the number of the specific row in a group.
@ -27,15 +27,15 @@ import org.apache.spark.sql.connector.expressions.FieldReference;
*/ */
@Evolving @Evolving
public final class Count implements AggregateFunc { public final class Count implements AggregateFunc {
private final FieldReference column; private final NamedReference column;
private final boolean isDistinct; private final boolean isDistinct;
public Count(FieldReference column, boolean isDistinct) { public Count(NamedReference column, boolean isDistinct) {
this.column = column; this.column = column;
this.isDistinct = isDistinct; this.isDistinct = isDistinct;
} }
public FieldReference column() { return column; } public NamedReference column() { return column; }
public boolean isDistinct() { return isDistinct; } public boolean isDistinct() { return isDistinct; }
@Override @Override

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate; package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.FieldReference; import org.apache.spark.sql.connector.expressions.NamedReference;
/** /**
* An aggregate function that returns the maximum value in a group. * An aggregate function that returns the maximum value in a group.
@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.expressions.FieldReference;
*/ */
@Evolving @Evolving
public final class Max implements AggregateFunc { public final class Max implements AggregateFunc {
private final FieldReference column; private final NamedReference column;
public Max(FieldReference column) { this.column = column; } public Max(NamedReference column) { this.column = column; }
public FieldReference column() { return column; } public NamedReference column() { return column; }
@Override @Override
public String toString() { return "MAX(" + column.describe() + ")"; } public String toString() { return "MAX(" + column.describe() + ")"; }

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate; package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.FieldReference; import org.apache.spark.sql.connector.expressions.NamedReference;
/** /**
* An aggregate function that returns the minimum value in a group. * An aggregate function that returns the minimum value in a group.
@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.expressions.FieldReference;
*/ */
@Evolving @Evolving
public final class Min implements AggregateFunc { public final class Min implements AggregateFunc {
private final FieldReference column; private final NamedReference column;
public Min(FieldReference column) { this.column = column; } public Min(NamedReference column) { this.column = column; }
public FieldReference column() { return column; } public NamedReference column() { return column; }
@Override @Override
public String toString() { return "MIN(" + column.describe() + ")"; } public String toString() { return "MIN(" + column.describe() + ")"; }

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate; package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving; import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.FieldReference; import org.apache.spark.sql.connector.expressions.NamedReference;
/** /**
* An aggregate function that returns the summation of all the values in a group. * An aggregate function that returns the summation of all the values in a group.
@ -27,15 +27,15 @@ import org.apache.spark.sql.connector.expressions.FieldReference;
*/ */
@Evolving @Evolving
public final class Sum implements AggregateFunc { public final class Sum implements AggregateFunc {
private final FieldReference column; private final NamedReference column;
private final boolean isDistinct; private final boolean isDistinct;
public Sum(FieldReference column, boolean isDistinct) { public Sum(NamedReference column, boolean isDistinct) {
this.column = column; this.column = column;
this.isDistinct = isDistinct; this.isDistinct = isDistinct;
} }
public FieldReference column() { return column; } public NamedReference column() { return column; }
public boolean isDistinct() { return isDistinct; } public boolean isDistinct() { return isDistinct; }
@Override @Override

View file

@ -702,20 +702,19 @@ object DataSourceStrategy
if (aggregates.filter.isEmpty) { if (aggregates.filter.isEmpty) {
aggregates.aggregateFunction match { aggregates.aggregateFunction match {
case aggregate.Min(PushableColumnWithoutNestedColumn(name)) => case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
Some(new Min(FieldReference(name).asInstanceOf[FieldReference])) Some(new Min(FieldReference(name)))
case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
Some(new Max(FieldReference(name).asInstanceOf[FieldReference])) Some(new Max(FieldReference(name)))
case count: aggregate.Count if count.children.length == 1 => case count: aggregate.Count if count.children.length == 1 =>
count.children.head match { count.children.head match {
// SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
case Literal(_, _) => Some(new CountStar()) case Literal(_, _) => Some(new CountStar())
case PushableColumnWithoutNestedColumn(name) => case PushableColumnWithoutNestedColumn(name) =>
Some(new Count(FieldReference(name).asInstanceOf[FieldReference], Some(new Count(FieldReference(name), aggregates.isDistinct))
aggregates.isDistinct))
case _ => None case _ => None
} }
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct)) Some(new Sum(FieldReference(name), aggregates.isDistinct))
case _ => None case _ => None
} }
} else { } else {