[SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up

### What changes were proposed in this pull request?
Update the Javadocs and the JDBC data source doc, and address follow-up review comments.

### Why are the changes needed?
To update the docs and address follow-up review comments.

### Does this PR introduce _any_ user-facing change?
Yes, the new JDBC option `pushDownAggregate` is added to the JDBC data source doc.

### How was this patch tested?
Manually checked.

Closes #33526 from huaxingao/aggPD_followup.

Authored-by: Huaxin Gao <huaxin_gao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Authored by Huaxin Gao on 2021-07-28 12:52:42 +08:00, committed by Wenchen Fan
parent 1614d00417
commit c8dd97d456
13 changed files with 78 additions and 99 deletions

docs/sql-data-sources-jdbc.md

@@ -237,6 +237,15 @@ logging into the data sources.
     <td>read</td>
   </tr>
+  <tr>
+    <td><code>pushDownAggregate</code></td>
+    <td><code>false</code></td>
+    <td>
+      The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if set to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source can't fully complete the aggregate and does a final aggregate over the data source output.
+    </td>
+    <td>read</td>
+  </tr>
   <tr>
     <td><code>keytab</code></td>
     <td>(none)</td>
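For reference, a minimal Scala sketch of how a reader might opt into the new behavior. The URL, table, and column names below are placeholders, and whether the aggregate is actually pushed down still depends on the conditions described in the option text above.

```scala
// Hypothetical example: enable aggregate push-down for a JDBC read.
// `spark` is an existing SparkSession; URL, table, and columns are made up.
import org.apache.spark.sql.functions.{max, sum}

val employees = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/hr")
  .option("dbtable", "public.employee")
  .option("pushDownAggregate", "true")   // default is false
  .load()

// If every aggregate function (and any related filter) can be pushed down,
// the grouping/aggregation below is evaluated largely by the database, and
// Spark only runs a final aggregate over the data source output.
val perDept = employees.groupBy("dept").agg(sum("salary"), max("bonus"))
perDept.show()
```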

Aggregation.java

@@ -28,19 +28,15 @@ import java.io.Serializable;
  */
 @Evolving
 public final class Aggregation implements Serializable {
-  private AggregateFunc[] aggregateExpressions;
-  private FieldReference[] groupByColumns;
+  private final AggregateFunc[] aggregateExpressions;
+  private final FieldReference[] groupByColumns;
 
   public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
     this.aggregateExpressions = aggregateExpressions;
     this.groupByColumns = groupByColumns;
   }
 
-  public AggregateFunc[] aggregateExpressions() {
-    return aggregateExpressions;
-  }
+  public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
 
-  public FieldReference[] groupByColumns() {
-    return groupByColumns;
-  }
+  public FieldReference[] groupByColumns() { return groupByColumns; }
 }

Count.java

@@ -26,24 +26,20 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Count implements AggregateFunc {
-  private FieldReference column;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final boolean isDistinct;
 
   public Count(FieldReference column, boolean isDistinct) {
     this.column = column;
     this.isDistinct = isDistinct;
   }
 
-  public FieldReference column() {
-    return column;
-  }
-
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

CountStar.java

@@ -27,14 +27,12 @@ import org.apache.spark.annotation.Evolving;
 @Evolving
 public final class CountStar implements AggregateFunc {
 
   public CountStar() {
   }
 
   @Override
-  public String toString() {
-    return "CountStar()";
-  }
+  public String toString() { return "CountStar()"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

Max.java

@@ -26,19 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Max implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;
 
-  public Max(FieldReference column) {
-    this.column = column;
-  }
+  public Max(FieldReference column) { this.column = column; }
 
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() {
-    return "Max(" + column.describe() + ")";
-  }
+  public String toString() { return "Max(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

Min.java

@@ -26,21 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Min implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;
 
-  public Min(FieldReference column) {
-    this.column = column;
-  }
+  public Min(FieldReference column) { this.column = column; }
 
-  public FieldReference column() {
-    return column;
-  }
+  public FieldReference column() { return column; }
 
   @Override
-  public String toString() {
-    return "Min(" + column.describe() + ")";
-  }
+  public String toString() { return "Min(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

Sum.java

@@ -27,31 +27,25 @@ import org.apache.spark.sql.types.DataType;
  */
 @Evolving
 public final class Sum implements AggregateFunc {
-  private FieldReference column;
-  private DataType dataType;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final DataType dataType;
+  private final boolean isDistinct;
 
   public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
     this.column = column;
     this.dataType = dataType;
     this.isDistinct = isDistinct;
   }
 
-  public FieldReference column() {
-    return column;
-  }
-
-  public DataType dataType() {
-    return dataType;
-  }
-
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public DataType dataType() { return dataType; }
+  public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() {
     return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
   }
 
   @Override
   public String describe() { return this.toString(); }
 }

SupportsPushDownAggregates.java

@@ -27,12 +27,10 @@ import org.apache.spark.sql.connector.expressions.Aggregation;
  * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
  * to the data source, the data source can still output data with duplicated keys, which is OK
  * as Spark will do GROUP BY key again. The final query plan can be something like this:
- * {{{
+ * <pre>
  *   Aggregate [key#1], [min(min(value)#2) AS m#3]
  *     +- RelationV2[key#1, min(value)#2]
- * }}}
- *
+ * </pre>
+ * <p>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
  *
@@ -51,6 +49,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
+   *
+   * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
 }
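To make the contract concrete, here is a minimal Scala sketch (not part of the PR) of a ScanBuilder that accepts the pushed aggregation only when it can evaluate every function. The class name is made up, and the import paths assume the same connector.expressions package as Aggregation above.

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, Max, Min}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

// Hypothetical builder for a source that can only evaluate MIN/MAX remotely.
class MinMaxScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Return true only if every aggregate function is supported; returning
    // false keeps the whole Aggregate node on the Spark side.
    val supported = aggregation.aggregateExpressions.forall {
      case _: Min | _: Max => true
      case _ => false
    }
    if (supported) pushedAggregation = Some(aggregation)
    supported
  }

  // A real implementation would build a Scan whose output is: grouping
  // columns first, then one column per pushed aggregate function.
  override def build(): Scan = ???
}
```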

JDBCRDD.scala

@@ -148,12 +148,12 @@ object JDBCRDD extends Logging {
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
         assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDinstinct) "DISTINCT" else ""
+        val distinct = if (count.isDistinct) "DISTINCT" else ""
         val column = quote(count.column.fieldNames.head)
         s"COUNT($distinct $column)"
       case sum: Sum =>
         assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDinstinct) "DISTINCT" else ""
+        val distinct = if (sum.isDistinct) "DISTINCT" else ""
         val column = quote(sum.column.fieldNames.head)
         s"SUM($distinct $column)"
       case _: CountStar =>
@@ -172,8 +172,8 @@ object JDBCRDD extends Logging {
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *    per-partition WHERE clauses.
    * @param options - JDBC options that contains url, table and other information.
-   * @param requiredSchema - The schema of the columns to SELECT.
-   * @param aggregation - The pushed down aggregation
+   * @param outputSchema - The schema of the columns to SELECT.
+   * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
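For context, a self-contained Scala sketch of the mapping this hunk touches, written against the renamed isDistinct getters. The quote helper stands in for the JDBC dialect's identifier quoting, the surrounding JDBCRDD plumbing is omitted, and the import path is assumed to match the connector.expressions package above.

```scala
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}

// Stand-in for JdbcDialect.quoteIdentifier.
def quote(colName: String): String = "\"" + colName + "\""

// Translate one pushed-down aggregate function into a SQL fragment; None means
// the function cannot be compiled and push-down should be abandoned.
def compileAggregate(func: AggregateFunc): Option[String] = func match {
  case min: Min => Some(s"MIN(${quote(min.column.fieldNames.head)})")
  case max: Max => Some(s"MAX(${quote(max.column.fieldNames.head)})")
  case count: Count =>
    val distinct = if (count.isDistinct) "DISTINCT " else ""
    Some(s"COUNT($distinct${quote(count.column.fieldNames.head)})")
  case sum: Sum =>
    val distinct = if (sum.isDistinct) "DISTINCT " else ""
    Some(s"SUM($distinct${quote(sum.column.fieldNames.head)})")
  case _: CountStar => Some("COUNT(*)")
  case _ => None
}
```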

PushDownUtils.scala

@@ -92,8 +92,8 @@ object PushDownUtils extends PredicateHelper {
     scanBuilder match {
       case r: SupportsPushDownAggregates =>
-        val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
-        val translatedGroupBys = groupBy.map(columnAsString).flatten
+        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
+        val translatedGroupBys = groupBy.flatMap(columnAsString)
 
         if (translatedAggregates.length != aggregates.length ||
           translatedGroupBys.length != groupBy.length) {
@@ -101,11 +101,7 @@ object PushDownUtils extends PredicateHelper {
         }
 
         val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-        if (r.pushAggregation(agg)) {
-          Some(agg)
-        } else {
-          None
-        }
+        Some(agg).filter(r.pushAggregation)
       case _ => None
     }
   }
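The second hunk is a pure refactor: it leans on Option.filter, which keeps the wrapped value only when the predicate holds. A small stand-alone illustration of the equivalence (the predicate and value are made up):

```scala
// Stand-in predicate playing the role of r.pushAggregation.
def pushAggregation(agg: String): Boolean = agg.nonEmpty

val agg = "Aggregation(...)"

val verbose = if (pushAggregation(agg)) Some(agg) else None
val concise = Some(agg).filter(pushAggregation)

assert(verbose == concise)  // both are Some("Aggregation(...)")
```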

V2ScanRelationPushDown.scala

@@ -33,7 +33,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+    applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
   }
 
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -68,7 +68,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
   }
 
-  def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+  def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {

JDBCScanBuilder.scala

@@ -85,14 +85,14 @@ case class JDBCScanBuilder(
         val structField = getStructFieldForCol(min.column)
         outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
       case count: Count =>
-        val distinct = if (count.isDinstinct) "DISTINCT " else ""
+        val distinct = if (count.isDistinct) "DISTINCT " else ""
         val structField = getStructFieldForCol(count.column)
         outputSchema =
           outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
       case _: CountStar =>
         outputSchema = outputSchema.add(StructField("count(*)", LongType))
       case sum: Sum =>
-        val distinct = if (sum.isDinstinct) "DISTINCT " else ""
+        val distinct = if (sum.isDistinct) "DISTINCT " else ""
         val structField = getStructFieldForCol(sum.column)
         outputSchema =
           outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
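To illustrate what this block produces, a small sketch of the output schema the JDBC scan would report for a pushed count(DISTINCT name) plus count(*). The column names are made up, and the real code derives field metadata via getStructFieldForCol.

```scala
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Start from an empty schema and append one field per pushed aggregate,
// mirroring how JDBCScanBuilder accumulates `outputSchema`.
var outputSchema = new StructType()
outputSchema = outputSchema.add(StructField("count(DISTINCT NAME)", LongType))
outputSchema = outputSchema.add(StructField("count(*)", LongType))

// The scan's readSchema() would then expose exactly these two LongType columns,
// which Spark feeds into its final aggregate.
println(outputSchema.treeString)
```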

JDBCV2Suite.scala

@@ -453,7 +453,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias") {
+  test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")