[SPARK-34952][SQL][FOLLOW-UP] DSv2 aggregate push down follow-up

### What changes were proposed in this pull request?
Update the Java docs and the JDBC data source doc, and address follow-up review comments.

### Why are the changes needed?
To update the documentation and address the follow-up review comments on the DSv2 aggregate push-down work.

### Does this PR introduce _any_ user-facing change?
Yes. The new JDBC option `pushDownAggregate` is now documented in the JDBC data source doc.

### How was this patch tested?
Manually checked.

Closes #33526 from huaxingao/aggPD_followup.

Authored-by: Huaxin Gao <huaxin_gao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Huaxin Gao 2021-07-28 12:52:42 +08:00 committed by Wenchen Fan
parent 1614d00417
commit c8dd97d456
13 changed files with 78 additions and 99 deletions

View file

@@ -237,6 +237,15 @@ logging into the data sources.
     <td>read</td>
   </tr>
+  <tr>
+    <td><code>pushDownAggregate</code></td>
+    <td><code>false</code></td>
+    <td>
+      The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark does not push down aggregates to the JDBC data source. If set to true, aggregates are pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source cannot fully complete the aggregate and does a final aggregate over the data source output.
+    </td>
+    <td>read</td>
+  </tr>
   <tr>
     <td><code>keytab</code></td>
     <td>(none)</td>
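For context, a minimal usage sketch of the new option (the URL, table, and column names below are placeholders, not part of this PR):

```scala
import org.apache.spark.sql.functions.max

// Assumes an active SparkSession `spark` and a reachable JDBC endpoint (both hypothetical).
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://host:5432/db")
  .option("dbtable", "employees")
  .option("pushDownAggregate", "true")
  .load()

// With the option enabled, MAX(salary) can be evaluated by the database itself;
// Spark still runs a final aggregate over whatever the source returns.
df.groupBy("dept").agg(max("salary")).show()
```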

View file

@@ -28,19 +28,15 @@ import java.io.Serializable;
  */
 @Evolving
 public final class Aggregation implements Serializable {
-  private AggregateFunc[] aggregateExpressions;
-  private FieldReference[] groupByColumns;
+  private final AggregateFunc[] aggregateExpressions;
+  private final FieldReference[] groupByColumns;
 
   public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
     this.aggregateExpressions = aggregateExpressions;
     this.groupByColumns = groupByColumns;
   }
 
-  public AggregateFunc[] aggregateExpressions() {
-    return aggregateExpressions;
-  }
+  public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
 
-  public FieldReference[] groupByColumns() {
-    return groupByColumns;
-  }
+  public FieldReference[] groupByColumns() { return groupByColumns; }
 }

View file

@@ -26,24 +26,20 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Count implements AggregateFunc {
-  private FieldReference column;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final boolean isDistinct;
 
   public Count(FieldReference column, boolean isDistinct) {
     this.column = column;
     this.isDistinct = isDistinct;
   }
 
-  public FieldReference column() {
-    return column;
-  }
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

View file

@@ -27,14 +27,12 @@ import org.apache.spark.annotation.Evolving;
 @Evolving
 public final class CountStar implements AggregateFunc {
 
   public CountStar() {
   }
 
   @Override
-  public String toString() {
-    return "CountStar()";
-  }
+  public String toString() { return "CountStar()"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

View file

@@ -26,19 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Max implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;
 
-  public Max(FieldReference column) {
-    this.column = column;
-  }
+  public Max(FieldReference column) { this.column = column; }
 
   public FieldReference column() { return column; }
 
   @Override
-  public String toString() {
-    return "Max(" + column.describe() + ")";
-  }
+  public String toString() { return "Max(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

View file

@@ -26,21 +26,15 @@ import org.apache.spark.annotation.Evolving;
  */
 @Evolving
 public final class Min implements AggregateFunc {
-  private FieldReference column;
+  private final FieldReference column;
 
-  public Min(FieldReference column) {
-    this.column = column;
-  }
+  public Min(FieldReference column) { this.column = column; }
 
-  public FieldReference column() {
-    return column;
-  }
+  public FieldReference column() { return column; }
 
   @Override
-  public String toString() {
-    return "Min(" + column.describe() + ")";
-  }
+  public String toString() { return "Min(" + column.describe() + ")"; }
 
   @Override
   public String describe() { return this.toString(); }
 }

View file

@@ -27,31 +27,25 @@ import org.apache.spark.sql.types.DataType;
  */
 @Evolving
 public final class Sum implements AggregateFunc {
-  private FieldReference column;
-  private DataType dataType;
-  private boolean isDistinct;
+  private final FieldReference column;
+  private final DataType dataType;
+  private final boolean isDistinct;
 
   public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
     this.column = column;
     this.dataType = dataType;
     this.isDistinct = isDistinct;
   }
 
-  public FieldReference column() {
-    return column;
-  }
-  public DataType dataType() {
-    return dataType;
-  }
-  public boolean isDinstinct() {
-    return isDistinct;
-  }
+  public FieldReference column() { return column; }
+  public DataType dataType() { return dataType; }
+  public boolean isDistinct() { return isDistinct; }
 
   @Override
   public String toString() {
     return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
   }
 
   @Override
   public String describe() { return this.toString(); }
 }

View file

@@ -27,12 +27,10 @@ import org.apache.spark.sql.connector.expressions.Aggregation;
  * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
  * to the data source, the data source can still output data with duplicated keys, which is OK
  * as Spark will do GROUP BY key again. The final query plan can be something like this:
- * {{{
+ * <pre>
  *   Aggregate [key#1], [min(min(value)#2) AS m#3]
  *     +- RelationV2[key#1, min(value)#2]
- * }}}
- *
- * <p>
+ * </pre>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
  *
@@ -51,6 +49,8 @@ public interface SupportsPushDownAggregates extends ScanBuilder {
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
    * the given Aggregation).
+   *
+   * @return true if the aggregation can be pushed down to datasource, false otherwise.
    */
   boolean pushAggregation(Aggregation aggregation);
 }
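To make the contract concrete, here is a minimal, hypothetical `ScanBuilder` sketch (not the real `JDBCScanBuilder`) that accepts the push-down only when every aggregate function is one it can evaluate; the aggregate classes are assumed to live next to `Aggregation` in `org.apache.spark.sql.connector.expressions`, as the import in this hunk suggests:

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Accept the push-down only if this source knows every aggregate function.
    val supported = aggregation.aggregateExpressions().forall {
      case _: Min | _: Max | _: Count | _: CountStar | _: Sum => true
      case _ => false
    }
    if (supported) pushedAggregation = Some(aggregation)
    supported // returning false keeps the whole aggregate on the Spark side
  }

  override def build(): Scan = {
    // The scan's output order must be: grouping columns first, then one column per
    // aggregate function, matching the order in pushedAggregation (see javadoc above).
    throw new UnsupportedOperationException("sketch only")
  }
}
```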

View file

@@ -148,12 +148,12 @@ object JDBCRDD extends Logging {
         s"MAX(${quote(max.column.fieldNames.head)})"
       case count: Count =>
         assert(count.column.fieldNames.length == 1)
-        val distinct = if (count.isDinstinct) "DISTINCT" else ""
+        val distinct = if (count.isDistinct) "DISTINCT" else ""
         val column = quote(count.column.fieldNames.head)
         s"COUNT($distinct $column)"
       case sum: Sum =>
         assert(sum.column.fieldNames.length == 1)
-        val distinct = if (sum.isDinstinct) "DISTINCT" else ""
+        val distinct = if (sum.isDistinct) "DISTINCT" else ""
         val column = quote(sum.column.fieldNames.head)
         s"SUM($distinct $column)"
       case _: CountStar =>
@@ -172,8 +172,8 @@ object JDBCRDD extends Logging {
    * @param parts - An array of JDBCPartitions specifying partition ids and
    *                per-partition WHERE clauses.
    * @param options - JDBC options that contains url, table and other information.
-   * @param requiredSchema - The schema of the columns to SELECT.
-   * @param aggregation - The pushed down aggregation
+   * @param outputSchema - The schema of the columns to SELECT.
+   * @param groupByColumns - The pushed down group by columns.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
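As a quick, self-contained illustration of the strings this compilation produces (the `quote` helper below is a stand-in for the JDBC dialect's identifier quoting used by the real code):

```scala
// Stand-in for the dialect's identifier quoting.
def quote(colName: String): String = "\"" + colName + "\""

// Mirrors the COUNT branch above, with the isDistinct accessor name fixed by this PR.
def compileCount(colName: String, isDistinct: Boolean): String = {
  val distinct = if (isDistinct) "DISTINCT" else ""
  s"COUNT($distinct ${quote(colName)})"
}

compileCount("NAME", isDistinct = true)  // COUNT(DISTINCT "NAME")
compileCount("NAME", isDistinct = false) // COUNT( "NAME")
```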

View file

@@ -92,8 +92,8 @@ object PushDownUtils extends PredicateHelper {
     scanBuilder match {
       case r: SupportsPushDownAggregates =>
-        val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
-        val translatedGroupBys = groupBy.map(columnAsString).flatten
+        val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
+        val translatedGroupBys = groupBy.flatMap(columnAsString)
 
         if (translatedAggregates.length != aggregates.length ||
           translatedGroupBys.length != groupBy.length) {
@@ -101,11 +101,7 @@ object PushDownUtils extends PredicateHelper {
         }
 
         val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
-        if (r.pushAggregation(agg)) {
-          Some(agg)
-        } else {
-          None
-        }
+        Some(agg).filter(r.pushAggregation)
       case _ => None
     }
   }
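The `Some(agg).filter(r.pushAggregation)` rewrite is behavior-preserving because `Option.filter` keeps the value only when the predicate returns true:

```scala
// Equivalent to: if (p(x)) Some(x) else None
assert(Some(42).filter(_ > 0) == Some(42)) // predicate true  -> Some(agg)
assert(Some(42).filter(_ < 0) == None)     // predicate false -> None
```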

View file

@@ -33,7 +33,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+    applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
   }
 
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -68,7 +68,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
   }
 
-  def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+  def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
     // update the scan builder with agg pushdown and return a new plan with agg pushed
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {

View file

@@ -85,14 +85,14 @@ case class JDBCScanBuilder(
           val structField = getStructFieldForCol(min.column)
           outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
         case count: Count =>
-          val distinct = if (count.isDinstinct) "DISTINCT " else ""
+          val distinct = if (count.isDistinct) "DISTINCT " else ""
           val structField = getStructFieldForCol(count.column)
           outputSchema =
             outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
         case _: CountStar =>
           outputSchema = outputSchema.add(StructField("count(*)", LongType))
         case sum: Sum =>
-          val distinct = if (sum.isDinstinct) "DISTINCT " else ""
+          val distinct = if (sum.isDistinct) "DISTINCT " else ""
           val structField = getStructFieldForCol(sum.column)
           outputSchema =
             outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
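For illustration, the column names produced by the schema-building code above look like the following (the NAME/SALARY columns and their types are hypothetical, not taken from the H2 test tables):

```scala
import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

// "DISTINCT " (with its trailing space) is spliced in only for a distinct Count/Sum.
var outputSchema = new StructType()
outputSchema = outputSchema.add(StructField("count(DISTINCT NAME)", LongType))
outputSchema = outputSchema.add(StructField("sum(SALARY)", DoubleType))
// Spark's final aggregate reads these columns from the JDBC scan and merges them.
```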

View file

@@ -453,7 +453,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(query, Seq(Row(47100.0)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias") {
+  test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")