diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index 6d44a229bf..16d525eea3 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -237,6 +237,15 @@ logging into the data sources.
     <td>read</td>
   </tr>

+  <tr>
+    <td><code>pushDownAggregate</code></td>
+    <td><code>false</code></td>
+    <td>
+     The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if set to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source can't fully complete the aggregate and does a final aggregate over the data source output.
+    </td>
+    <td>read</td>
+  </tr>
+
   <tr>
     <td><code>keytab</code></td>
     <td>(none)</td>
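As a quick usage sketch of the option documented above: the JDBC URL, table, and column names below are invented, and any source supported by the JDBC connector would behave the same way.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, max}

val spark = SparkSession.builder().appName("agg-pushdown-sketch").getOrCreate()

// Hypothetical connection details; pushDownAggregate defaults to false.
val employees = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost:5432/hr")
  .option("dbtable", "employee")
  .option("pushDownAggregate", "true")
  .load()

// If MAX/COUNT and the related filters are all translatable, the database can
// evaluate them; Spark still runs a final aggregate over the source output.
employees.groupBy("dept").agg(max("salary"), count("name")).show()
```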
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
index fdf30312f1..8eb3491ea1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
@@ -28,19 +28,15 @@ import java.io.Serializable;
  */
 @Evolving
 public final class Aggregation implements Serializable {
-  private AggregateFunc[] aggregateExpressions;
-  private FieldReference[] groupByColumns;
+  private final AggregateFunc[] aggregateExpressions;
+  private final FieldReference[] groupByColumns;

   public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
     this.aggregateExpressions = aggregateExpressions;
     this.groupByColumns = groupByColumns;
   }

-  public AggregateFunc[] aggregateExpressions() {
-    return aggregateExpressions;
-  }
+  public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }

-  public FieldReference[] groupByColumns() {
-    return groupByColumns;
-  }
+  public FieldReference[] groupByColumns() { return groupByColumns; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index 17562a1aa1..0e28a939e3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -26,24 +26,20 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Count implements AggregateFunc {
-  private FieldReference column;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final boolean isDistinct;

-  public Count(FieldReference column, boolean isDistinct) {
-    this.column = column;
-    this.isDistinct = isDistinct;
-  }
+  public Count(FieldReference column, boolean isDistinct) {
+    this.column = column;
+    this.isDistinct = isDistinct;
+  }

-  public FieldReference column() {
-    return column;
-  }
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public boolean isDistinct() { return isDistinct; }

-  @Override
-  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+  @Override
+  public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }

-  @Override
-  public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index 777a99d58e..21a3564480 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -27,14 +27,12 @@ import org.apache.spark.annotation.Evolving;
 @Evolving
 public final class CountStar implements AggregateFunc {

-  public CountStar() {
-  }
+  public CountStar() {
+  }

-  @Override
-  public String toString() {
-    return "CountStar()";
-  }
+  @Override
+  public String toString() { return "CountStar()"; }

-  @Override
-  public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index fe7689c18a..d2ff6b2f04 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -26,19 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Max implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;

-  public Max(FieldReference column) {
-    this.column = column;
-  }
+  public Max(FieldReference column) { this.column = column; }

-  public FieldReference column() { return column; }
+  public FieldReference column() { return column; }

-  @Override
-  public String toString() {
-    return "Max(" + column.describe() + ")";
-  }
+  @Override
+  public String toString() { return "Max(" + column.describe() + ")"; }

-  @Override
-  public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index f528b0bedf..efa8036100 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -26,21 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Min implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;

-  public Min(FieldReference column) {
-    this.column = column;
-  }
+  public Min(FieldReference column) { this.column = column; }

-  public FieldReference column() {
-    return column;
-  }
+  public FieldReference column() { return column; }

-  @Override
-  public String toString() {
-    return "Min(" + column.describe() + ")";
-  }
+  @Override
+  public String toString() { return "Min(" + column.describe() + ")"; }

-  @Override
-  public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
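For orientation, the expression classes above (and Sum just below) are plain value holders that Spark assembles into an Aggregation before calling the connector. A rough sketch under that assumption; the FieldReference construction is illustrative and not necessarily how the planner builds column references.

```scala
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Aggregation, Count, FieldReference, Max}

// Roughly what "SELECT dept, MAX(salary), COUNT(DISTINCT name) FROM employee GROUP BY dept"
// could be translated into before pushAggregation is invoked; names are made up.
val pushed = new Aggregation(
  Array[AggregateFunc](
    new Max(FieldReference(Seq("salary"))),
    new Count(FieldReference(Seq("name")), true)),
  Array(FieldReference(Seq("dept"))))
```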
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
index 4cb34bee28..e4e860e3f3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -27,31 +27,25 @@ import org.apache.spark.sql.types.DataType;
  */
 @Evolving
 public final class Sum implements AggregateFunc {
-  private FieldReference column;
-  private DataType dataType;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final DataType dataType;
+  private final boolean isDistinct;

-  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
-    this.column = column;
-    this.dataType = dataType;
-    this.isDistinct = isDistinct;
-  }
+  public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+    this.column = column;
+    this.dataType = dataType;
+    this.isDistinct = isDistinct;
+  }

-  public FieldReference column() {
-    return column;
-  }
-  public DataType dataType() {
-    return dataType;
-  }
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public DataType dataType() { return dataType; }
+  public boolean isDistinct() { return isDistinct; }

-  @Override
-  public String toString() {
-    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
-  }
+  @Override
+  public String toString() {
+    return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
+  }

-  @Override
-  public String describe() { return this.toString(); }
+  @Override
+  public String describe() { return this.toString(); }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 7efa333bda..8ec9a2597a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -27,12 +27,10 @@ import org.apache.spark.sql.connector.expressions.Aggregation;
  * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
  * to the data source, the data source can still output data with duplicated keys, which is OK
  * as Spark will do GROUP BY key again. The final query plan can be something like this:
- * {{{
+ *
  *   Aggregate [key#1], [min(min(value)#2) AS m#3]
  *     +- RelationV2[key#1, min(value)#2]
- * }}}
- *
- * 
+ *
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
  *
@@ -51,6 +49,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
+   *
+   * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
 }
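The javadoc above describes a partial-aggregate contract: the source may emit several rows per group and Spark's final Aggregate merges them. A plain Scala illustration of that min(min(value)) shape, with made-up data:

```scala
// Each tuple is a partial (key, min(value)) row as one data source split might
// return it; duplicated keys are fine because Spark groups by key again.
val partialRows = Seq(("a", 5), ("a", 3), ("b", 7))

// The final Aggregate effectively computes min(min(value)) per key.
val finalRows = partialRows.groupBy(_._1).map { case (k, vs) => (k, vs.map(_._2).min) }
assert(finalRows == Map("a" -> 3, "b" -> 7))
```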
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index c22ca1502b..af6c407e4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -148,12 +148,12 @@ object JDBCRDD extends Logging {
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
         assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDinstinct) "DISTINCT" else ""
+        val distinct = if (count.isDistinct) "DISTINCT" else ""
         val column = quote(count.column.fieldNames.head)
         s"COUNT($distinct $column)"
       case sum: Sum =>
         assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDinstinct) "DISTINCT" else ""
+        val distinct = if (sum.isDistinct) "DISTINCT" else ""
         val column = quote(sum.column.fieldNames.head)
         s"SUM($distinct $column)"
       case _: CountStar =>
@@ -172,8 +172,8 @@
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *                per-partition WHERE clauses.
    * @param options - JDBC options that contains url, table and other information.
-   * @param requiredSchema - The schema of the columns to SELECT.
-   * @param aggregation - The pushed down aggregation
+   * @param outputSchema - The schema of the columns to SELECT.
+   * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index ab5c5da43a..34b64313c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -92,8 +92,8 @@
     scanBuilder match {
       case r: SupportsPushDownAggregates =>
-        val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
-        val translatedGroupBys = groupBy.map(columnAsString).flatten
+        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
+        val translatedGroupBys = groupBy.flatMap(columnAsString)

         if (translatedAggregates.length != aggregates.length ||
           translatedGroupBys.length != groupBy.length) {
@@ -101,11 +101,7 @@
         }

         val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-        if (r.pushAggregation(agg)) {
-          Some(agg)
-        } else {
-          None
-        }
+        Some(agg).filter(r.pushAggregation)
       case _ => None
     }
   }
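The PushDownUtils change above leans on Option.filter, which keeps the value only when the predicate holds, giving the same behaviour as the removed if/else. A self-contained check of that equivalence with a dummy predicate:

```scala
// Some(agg).filter(r.pushAggregation) yields Some(agg) when the source accepts
// the aggregation and None otherwise.
def pushIfAccepted[A](agg: A)(accepts: A => Boolean): Option[A] =
  Some(agg).filter(accepts)

assert(pushIfAccepted("MAX(salary)")(_ => true) == Some("MAX(salary)"))
assert(pushIfAccepted("MAX(salary)")(_ => false).isEmpty)
```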
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 445ff033d4..a1fc981a69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -33,7 +33,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._

   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+    applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
   }

   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -68,7 +68,7 @@
     filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
   }

-  def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+  def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 7442edaafd..afdc822c66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -85,14 +85,14 @@ case class JDBCScanBuilder(
           val structField = getStructFieldForCol(min.column)
           outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
         case count: Count =>
-          val distinct = if (count.isDinstinct) "DISTINCT " else ""
+          val distinct = if (count.isDistinct) "DISTINCT " else ""
           val structField = getStructFieldForCol(count.column)
           outputSchema =
             outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
         case _: CountStar =>
           outputSchema = outputSchema.add(StructField("count(*)", LongType))
         case sum: Sum =>
-          val distinct = if (sum.isDinstinct) "DISTINCT " else ""
+          val distinct = if (sum.isDistinct) "DISTINCT " else ""
           val structField = getStructFieldForCol(sum.column)
           outputSchema =
             outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c1f8f5f00e..8dfb6defa4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -453,7 +453,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }

-  test("scan with aggregate push-down: aggregate over alias") {
+  test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
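To close, a hedged sketch of the scan output schema the JDBCScanBuilder code above produces for a pushed aggregation, following the pushAggregation ordering contract (grouping columns first, then one column per aggregate function); the column names and types here are invented.

```scala
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

// Grouping columns come first, then aggregate columns in aggregate-function order,
// named the way the builder above names them. MIN/MAX keep the column's own type,
// COUNT is LongType, and SUM uses the Sum's declared data type.
val scanOutput = new StructType()
  .add(StructField("DEPT", IntegerType))
  .add(StructField("max(SALARY)", IntegerType))
  .add(StructField("count(DISTINCT NAME)", LongType))
```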