[SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up
### What changes were proposed in this pull request? update java doc, JDBC data source doc, address follow up comments ### Why are the changes needed? update doc and address follow up comments ### Does this PR introduce _any_ user-facing change? Yes, add the new JDBC option `pushDownAggregate` in JDBC data source doc. ### How was this patch tested? manually checked Closes #33526 from huaxingao/aggPD_followup. Authored-by: Huaxin Gao <huaxin_gao@apple.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
1614d00417
commit
c8dd97d456
|
@ -237,6 +237,15 @@ logging into the data sources.
|
|||
<td>read</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td><code>pushDownAggregate</code></td>
|
||||
<td><code>false</code></td>
|
||||
<td>
|
||||
The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if sets to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source can't fully complete the aggregate and does a final aggregate over the data source output.
|
||||
</td>
|
||||
<td>read</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td><code>keytab</code></td>
|
||||
<td>(none)</td>
|
||||
|
|
|
@ -28,19 +28,15 @@ import java.io.Serializable;
|
|||
*/
|
||||
@Evolving
|
||||
public final class Aggregation implements Serializable {
|
||||
private AggregateFunc[] aggregateExpressions;
|
||||
private FieldReference[] groupByColumns;
|
||||
private final AggregateFunc[] aggregateExpressions;
|
||||
private final FieldReference[] groupByColumns;
|
||||
|
||||
public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
|
||||
this.aggregateExpressions = aggregateExpressions;
|
||||
this.groupByColumns = groupByColumns;
|
||||
}
|
||||
|
||||
public AggregateFunc[] aggregateExpressions() {
|
||||
return aggregateExpressions;
|
||||
}
|
||||
public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
|
||||
|
||||
public FieldReference[] groupByColumns() {
|
||||
return groupByColumns;
|
||||
}
|
||||
public FieldReference[] groupByColumns() { return groupByColumns; }
|
||||
}
|
||||
|
|
|
@ -26,24 +26,20 @@ import org.apache.spark.annotation.Evolving;
|
|||
*/
|
||||
@Evolving
|
||||
public final class Count implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private boolean isDistinct;
|
||||
private final FieldReference column;
|
||||
private final boolean isDistinct;
|
||||
|
||||
public Count(FieldReference column, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
public Count(FieldReference column, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
public boolean isDinstinct() {
|
||||
return isDistinct;
|
||||
}
|
||||
public FieldReference column() { return column; }
|
||||
public boolean isDistinct() { return isDistinct; }
|
||||
|
||||
@Override
|
||||
public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
|
||||
@Override
|
||||
public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
||||
|
|
|
@ -27,14 +27,12 @@ import org.apache.spark.annotation.Evolving;
|
|||
@Evolving
|
||||
public final class CountStar implements AggregateFunc {
|
||||
|
||||
public CountStar() {
|
||||
}
|
||||
public CountStar() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CountStar()";
|
||||
}
|
||||
@Override
|
||||
public String toString() { return "CountStar()"; }
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
||||
|
|
|
@ -26,19 +26,15 @@ import org.apache.spark.annotation.Evolving;
|
|||
*/
|
||||
@Evolving
|
||||
public final class Max implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private final FieldReference column;
|
||||
|
||||
public Max(FieldReference column) {
|
||||
this.column = column;
|
||||
}
|
||||
public Max(FieldReference column) { this.column = column; }
|
||||
|
||||
public FieldReference column() { return column; }
|
||||
public FieldReference column() { return column; }
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Max(" + column.describe() + ")";
|
||||
}
|
||||
@Override
|
||||
public String toString() { return "Max(" + column.describe() + ")"; }
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
||||
|
|
|
@ -26,21 +26,15 @@ import org.apache.spark.annotation.Evolving;
|
|||
*/
|
||||
@Evolving
|
||||
public final class Min implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private final FieldReference column;
|
||||
|
||||
public Min(FieldReference column) {
|
||||
this.column = column;
|
||||
}
|
||||
public Min(FieldReference column) { this.column = column; }
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
public FieldReference column() { return column; }
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Min(" + column.describe() + ")";
|
||||
}
|
||||
@Override
|
||||
public String toString() { return "Min(" + column.describe() + ")"; }
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
||||
|
|
|
@ -27,31 +27,25 @@ import org.apache.spark.sql.types.DataType;
|
|||
*/
|
||||
@Evolving
|
||||
public final class Sum implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private DataType dataType;
|
||||
private boolean isDistinct;
|
||||
private final FieldReference column;
|
||||
private final DataType dataType;
|
||||
private final boolean isDistinct;
|
||||
|
||||
public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.dataType = dataType;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.dataType = dataType;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
public DataType dataType() {
|
||||
return dataType;
|
||||
}
|
||||
public boolean isDinstinct() {
|
||||
return isDistinct;
|
||||
}
|
||||
public FieldReference column() { return column; }
|
||||
public DataType dataType() { return dataType; }
|
||||
public boolean isDistinct() { return isDistinct; }
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
|
||||
}
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
||||
|
|
|
@ -27,12 +27,10 @@ import org.apache.spark.sql.connector.expressions.Aggregation;
|
|||
* "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
|
||||
* to the data source, the data source can still output data with duplicated keys, which is OK
|
||||
* as Spark will do GROUP BY key again. The final query plan can be something like this:
|
||||
* {{{
|
||||
* <pre>
|
||||
* Aggregate [key#1], [min(min(value)#2) AS m#3]
|
||||
* +- RelationV2[key#1, min(value)#2]
|
||||
* }}}
|
||||
*
|
||||
* <p>
|
||||
* </pre>
|
||||
* Similarly, if there is no grouping expression, the data source can still output more than one
|
||||
* rows.
|
||||
*
|
||||
|
@ -51,6 +49,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
|
|||
* Pushes down Aggregation to datasource. The order of the datasource scan output columns should
|
||||
* be: grouping columns, aggregate columns (in the same order as the aggregate functions in
|
||||
* the given Aggregation).
|
||||
*
|
||||
* @return true if the aggregation can be pushed down to datasource, false otherwise.
|
||||
*/
|
||||
boolean pushAggregation(Aggregation aggregation);
|
||||
}
|
||||
|
|
|
@ -148,12 +148,12 @@ object JDBCRDD extends Logging {
|
|||
s"MAX(${quote(max.column.fieldNames.head)})"
|
||||
case count: Count =>
|
||||
assert(count.column.fieldNames.length == 1)
|
||||
val distinct = if (count.isDinstinct) "DISTINCT" else ""
|
||||
val distinct = if (count.isDistinct) "DISTINCT" else ""
|
||||
val column = quote(count.column.fieldNames.head)
|
||||
s"COUNT($distinct $column)"
|
||||
case sum: Sum =>
|
||||
assert(sum.column.fieldNames.length == 1)
|
||||
val distinct = if (sum.isDinstinct) "DISTINCT" else ""
|
||||
val distinct = if (sum.isDistinct) "DISTINCT" else ""
|
||||
val column = quote(sum.column.fieldNames.head)
|
||||
s"SUM($distinct $column)"
|
||||
case _: CountStar =>
|
||||
|
@ -172,8 +172,8 @@ object JDBCRDD extends Logging {
|
|||
* @param parts - An array of JDBCPartitions specifying partition ids and
|
||||
* per-partition WHERE clauses.
|
||||
* @param options - JDBC options that contains url, table and other information.
|
||||
* @param requiredSchema - The schema of the columns to SELECT.
|
||||
* @param aggregation - The pushed down aggregation
|
||||
* @param outputSchema - The schema of the columns to SELECT.
|
||||
* @param groupByColumns - The pushed down group by columns.
|
||||
*
|
||||
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
|
||||
*/
|
||||
|
|
|
@ -92,8 +92,8 @@ object PushDownUtils extends PredicateHelper {
|
|||
|
||||
scanBuilder match {
|
||||
case r: SupportsPushDownAggregates =>
|
||||
val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
|
||||
val translatedGroupBys = groupBy.map(columnAsString).flatten
|
||||
val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
|
||||
val translatedGroupBys = groupBy.flatMap(columnAsString)
|
||||
|
||||
if (translatedAggregates.length != aggregates.length ||
|
||||
translatedGroupBys.length != groupBy.length) {
|
||||
|
@ -101,11 +101,7 @@ object PushDownUtils extends PredicateHelper {
|
|||
}
|
||||
|
||||
val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
|
||||
if (r.pushAggregation(agg)) {
|
||||
Some(agg)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
Some(agg).filter(r.pushAggregation)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,7 +33,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
|||
import DataSourceV2Implicits._
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
|
||||
applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
|
||||
}
|
||||
|
||||
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
|
||||
|
@ -68,7 +68,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
|||
filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
|
||||
}
|
||||
|
||||
def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
|
||||
def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
|
||||
// update the scan builder with agg pushdown and return a new plan with agg pushed
|
||||
case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
|
||||
child match {
|
||||
|
|
|
@ -85,14 +85,14 @@ case class JDBCScanBuilder(
|
|||
val structField = getStructFieldForCol(min.column)
|
||||
outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
|
||||
case count: Count =>
|
||||
val distinct = if (count.isDinstinct) "DISTINCT " else ""
|
||||
val distinct = if (count.isDistinct) "DISTINCT " else ""
|
||||
val structField = getStructFieldForCol(count.column)
|
||||
outputSchema =
|
||||
outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
|
||||
case _: CountStar =>
|
||||
outputSchema = outputSchema.add(StructField("count(*)", LongType))
|
||||
case sum: Sum =>
|
||||
val distinct = if (sum.isDinstinct) "DISTINCT " else ""
|
||||
val distinct = if (sum.isDistinct) "DISTINCT " else ""
|
||||
val structField = getStructFieldForCol(sum.column)
|
||||
outputSchema =
|
||||
outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
|
||||
|
|
|
@ -453,7 +453,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
|
|||
checkAnswer(query, Seq(Row(47100.0)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: aggregate over alias") {
|
||||
test("scan with aggregate push-down: aggregate over alias NOT push down") {
|
||||
val cols = Seq("a", "b", "c", "d")
|
||||
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
|
||||
val df2 = df1.groupBy().sum("c")
|
||||
|
|
Loading…
Reference in a new issue