[SPARK-34952][SQL] DSv2 Aggregate push down APIs

### What changes were proposed in this pull request?
Add interfaces and APIs to push down aggregates to V2 data sources.

### Why are the changes needed?
Improve performance: evaluating aggregates inside the data source avoids transferring every row to Spark and aggregating it there.

### Does this PR introduce _any_ user-facing change?
`SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` was added. If it is set to true, aggregates are pushed down to the data source.
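The JDBC source also gains a `pushDownAggregate` option (default `false`, see the `JDBCOptions` change below). A hedged end-to-end sketch, with the catalog name, H2 URL and `test.employee` table borrowed from the new `JDBCV2Suite` tests in this diff:

```scala
// The test suite sets these in SparkConf before the session starts; setting them at runtime
// is shown here only for brevity. All connection values are assumptions.
spark.conf.set("spark.sql.catalog.h2",
  "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:/tmp/testdb")
spark.conf.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
spark.conf.set("spark.sql.catalog.h2.pushDownAggregate", "true")

val df = spark.sql(
  "SELECT MAX(SALARY), MIN(BONUS) FROM h2.test.employee WHERE DEPT > 0 GROUP BY DEPT")
df.explain()
// The scan node is expected to report something like:
//   PushedAggregates: [Max(SALARY), Min(BONUS)], ..., PushedGroupby: [DEPT]
```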

### How was this patch tested?
New tests for aggregate push down were added in https://github.com/apache/spark/pull/32049. The original PR was split into two PRs; this PR doesn't contain new tests.

Closes #33352 from huaxingao/aggPushDownInterface.

Authored-by: Huaxin Gao <huaxin_gao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit c561ee6865)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Huaxin Gao 2021-07-26 16:01:22 +08:00 committed by Wenchen Fan
parent 1e17a5bc19
commit b1f522cf97
20 changed files with 1061 additions and 49 deletions

View file

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
import java.io.Serializable;
/**
* Base interface for aggregate functions.
*
* @since 3.2.0
*/
@Evolving
public interface AggregateFunc extends Expression, Serializable {
}

View file

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
import java.io.Serializable;
/**
* Aggregation in a SQL statement.
*
* @since 3.2.0
*/
@Evolving
public final class Aggregation implements Serializable {
private AggregateFunc[] aggregateExpressions;
private FieldReference[] groupByColumns;
public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
this.aggregateExpressions = aggregateExpressions;
this.groupByColumns = groupByColumns;
}
public AggregateFunc[] aggregateExpressions() {
return aggregateExpressions;
}
public FieldReference[] groupByColumns() {
return groupByColumns;
}
}
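A small usage sketch (column names are assumptions, and the snippet is assumed to compile under `org.apache.spark.sql`, where the connector-internal `FieldReference` is visible) showing how `SELECT MAX(SALARY), MIN(BONUS) ... GROUP BY DEPT` is represented with these classes:

```scala
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Aggregation, FieldReference, Max, Min}

// FieldReference.apply returns a NamedReference, hence the cast; this mirrors the usage in
// DataSourceStrategy.translateAggregate later in this diff.
val agg = new Aggregation(
  Array[AggregateFunc](
    new Max(FieldReference("SALARY").asInstanceOf[FieldReference]),
    new Min(FieldReference("BONUS").asInstanceOf[FieldReference])),
  Array(FieldReference("DEPT").asInstanceOf[FieldReference]))

agg.aggregateExpressions.map(_.describe)  // Array("Max(SALARY)", "Min(BONUS)")
agg.groupByColumns.map(_.describe)        // Array("DEPT")
```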

View file

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
* An aggregate function that returns the number of non-null values of the given column in a group.
*
* @since 3.2.0
*/
@Evolving
public final class Count implements AggregateFunc {
private FieldReference column;
private boolean isDistinct;
public Count(FieldReference column, boolean isDistinct) {
this.column = column;
this.isDistinct = isDistinct;
}
public FieldReference column() {
return column;
}
public boolean isDistinct() {
return isDistinct;
}
@Override
public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
@Override
public String describe() { return this.toString(); }
}

View file

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
* An aggregate function that returns the number of rows in a group.
*
* @since 3.2.0
*/
@Evolving
public final class CountStar implements AggregateFunc {
public CountStar() {
}
@Override
public String toString() {
return "CountStar()";
}
@Override
public String describe() { return this.toString(); }
}

View file

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
* An aggregate function that returns the maximum value in a group.
*
* @since 3.2.0
*/
@Evolving
public final class Max implements AggregateFunc {
private FieldReference column;
public Max(FieldReference column) {
this.column = column;
}
public FieldReference column() { return column; }
@Override
public String toString() {
return "Max(" + column.describe() + ")";
}
@Override
public String describe() { return this.toString(); }
}

View file

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
* An aggregate function that returns the minimum value in a group.
*
* @since 3.2.0
*/
@Evolving
public final class Min implements AggregateFunc {
private FieldReference column;
public Min(FieldReference column) {
this.column = column;
}
public FieldReference column() {
return column;
}
@Override
public String toString() {
return "Min(" + column.describe() + ")";
}
@Override
public String describe() { return this.toString(); }
}

View file

@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.types.DataType;
/**
* An aggregate function that returns the summation of all the values in a group.
*
* @since 3.2.0
*/
@Evolving
public final class Sum implements AggregateFunc {
private FieldReference column;
private DataType dataType;
private boolean isDistinct;
public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
this.column = column;
this.dataType = dataType;
this.isDistinct = isDistinct;
}
public FieldReference column() {
return column;
}
public DataType dataType() {
return dataType;
}
public boolean isDistinct() {
return isDistinct;
}
@Override
public String toString() {
return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
}
@Override
public String describe() { return this.toString(); }
}

View file

@ -22,7 +22,8 @@ import org.apache.spark.annotation.Evolving;
/**
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
* interfaces to do operator pushdown, and keep the operator pushdown result in the returned
* {@link Scan}.
* {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
* aggregates or applies column pruning.
*
* @since 3.0.0
*/

View file

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connector.read;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Aggregation;
/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down aggregates. Spark assumes that the data source can't fully complete the
* grouping work, and will group the data source output again. For queries like
* "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
* to the data source, the data source can still output data with duplicated keys, which is OK
* as Spark will do GROUP BY key again. The final query plan can be something like this:
* {{{
* Aggregate [key#1], [min(min(value)#2) AS m#3]
* +- RelationV2[key#1, min(value)#2]
* }}}
*
* <p>
* Similarly, if there is no grouping expression, the data source can still output more than one
* row.
*
* <p>
* When pushing down operators, Spark pushes down filters to the data source first, then pushes
* down aggregates or applies column pruning. Depending on the data source implementation,
* aggregates may or may not be pushed down together with filters. If the pushed filters still
* need to be evaluated after scanning, aggregates can't be pushed down.
*
* @since 3.2.0
*/
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {
/**
* Pushes down the aggregation to the data source. The order of the data source scan output
* columns should be: grouping columns, then aggregate columns (in the same order as the
* aggregate functions in the given Aggregation).
*/
boolean pushAggregation(Aggregation aggregation);
}
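A minimal, hypothetical connector-side sketch of the contract described above. The `DemoScanBuilder` name, the restriction to MIN/MAX, and the simplistic schema handling are assumptions (a real source would also wire the pushed aggregation into its reader); like the JDBC implementation below, it is assumed to live under `org.apache.spark.sql` because `FieldReference` is connector-internal in this version:

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, Max, Min}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
import org.apache.spark.sql.types.StructType

class DemoScanBuilder(tableSchema: StructType) extends ScanBuilder
    with SupportsPushDownAggregates {

  private var pushed: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Only accept MIN/MAX here; returning false keeps the whole aggregate in Spark.
    // Partial, per-group results are fine: Spark applies a final group-by anyway.
    val supported = aggregation.aggregateExpressions.forall {
      case _: Min | _: Max => true
      case _ => false
    }
    if (supported) pushed = Some(aggregation)
    supported
  }

  override def build(): Scan = new Scan {
    // Per the contract above: grouping columns first, then one column per aggregate
    // function, in the order given by the pushed Aggregation.
    override def readSchema(): StructType = pushed match {
      case Some(agg) =>
        val groupFields = agg.groupByColumns.map(c => tableSchema(c.fieldNames.head))
        val aggFields = agg.aggregateExpressions.map {
          case m: Min => tableSchema(m.column.fieldNames.head).copy(name = s"min(${m.column.describe})")
          case m: Max => tableSchema(m.column.fieldNames.head).copy(name = s"max(${m.column.describe})")
        }
        StructType(groupFields ++ aggFields)
      case None => tableSchema
    }
  }
}
```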

View file

@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.connector.expressions.Aggregation
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@ -102,6 +103,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
aggregation: Option[Aggregation],
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@ -129,12 +131,29 @@ case class RowDataSourceScanExec(
override def inputRDD: RDD[InternalRow] = rdd
override val metadata: Map[String, String] = {
val markedFilters = for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
val (aggString, groupByString) = if (aggregation.nonEmpty) {
(seqToString(aggregation.get.aggregateExpressions),
seqToString(aggregation.get.groupByColumns))
} else {
("[]", "[]")
}
val markedFilters = if (filters.nonEmpty) {
for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
}
} else {
handledFilters
}
Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
"PushedFilters" -> seqToString(markedFilters.toSeq),
"PushedAggregates" -> aggString,
"PushedGroupby" -> groupByString)
}
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.

View file

@ -33,12 +33,14 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@ -332,6 +334,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
None,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@ -405,6 +408,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@ -427,6 +431,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@ -692,6 +697,32 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}
protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
if (aggregates.filter.isEmpty) {
aggregates.aggregateFunction match {
case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
Some(new Min(FieldReference(name).asInstanceOf[FieldReference]))
case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
Some(new Max(FieldReference(name).asInstanceOf[FieldReference]))
case count: aggregate.Count if count.children.length == 1 =>
count.children.head match {
// SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
case Literal(_, _) => Some(new CountStar())
case PushableColumnWithoutNestedColumn(name) =>
Some(new Count(FieldReference(name).asInstanceOf[FieldReference],
aggregates.isDistinct))
case _ => None
}
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
sum.dataType, aggregates.isDistinct))
case _ => None
}
} else {
None
}
}
/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
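The `translateAggregate` helper above maps Catalyst aggregate expressions onto the new connector classes. A minimal sketch of the mapping, assuming the calling code lives under `org.apache.spark.sql` (the helper is `protected[sql]`) and using a made-up `SALARY` attribute:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.IntegerType

val salary = AttributeReference("SALARY", IntegerType)()

// MIN(SALARY) translates to Some(Min(FieldReference("SALARY"))).
val min = DataSourceStrategy.translateAggregate(
  aggregate.Min(salary).toAggregateExpression())

// COUNT(DISTINCT SALARY) translates to Some(Count(FieldReference("SALARY"), true)).
val distinctCount = DataSourceStrategy.translateAggregate(
  aggregate.Count(Seq(salary)).toAggregateExpression(isDistinct = true))

// Aggregates carrying a FILTER clause are not translated (None) and stay in Spark.
```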

View file

@ -188,6 +188,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down predicate into JDBC data source
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean
// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@ -259,6 +262,7 @@ object JDBCOptions {
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
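A small sketch of how the new option is parsed. `JDBCOptions` is an internal class and the connection values are assumptions; note that in this patch the option only takes effect through the DataSourceV2 `JDBCScanBuilder` path (e.g. via `JDBCTableCatalog`, as in the tests below):

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

val options = new JDBCOptions(Map(
  "url" -> "jdbc:h2:/tmp/testdb",      // assumed URL
  "dbtable" -> "test.employee",
  "pushDownAggregate" -> "true"))

assert(options.pushDownAggregate)       // defaults to false when the option is absent
```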

View file

@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@ -133,6 +134,34 @@ object JDBCRDD extends Logging {
})
}
def compileAggregates(
aggregates: Seq[AggregateFunc],
dialect: JdbcDialect): Seq[String] = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)
aggregates.map {
case min: Min =>
assert(min.column.fieldNames.length == 1)
s"MIN(${quote(min.column.fieldNames.head)})"
case max: Max =>
assert(max.column.fieldNames.length == 1)
s"MAX(${quote(max.column.fieldNames.head)})"
case count: Count =>
assert(count.column.fieldNames.length == 1)
val distinct = if (count.isDistinct) "DISTINCT" else ""
val column = quote(count.column.fieldNames.head)
s"COUNT($distinct $column)"
case sum: Sum =>
assert(sum.column.fieldNames.length == 1)
val distinct = if (sum.isDistinct) "DISTINCT" else ""
val column = quote(sum.column.fieldNames.head)
s"SUM($distinct $column)"
case _: CountStar =>
s"COUNT(1)"
case _ => ""
}
}
/**
* Build and return JDBCRDD from the given information.
*
@ -143,6 +172,8 @@ object JDBCRDD extends Logging {
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns to SELECT.
* @param groupByColumns - The pushed-down group-by columns, if any.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
@ -152,19 +183,27 @@ object JDBCRDD extends Logging {
requiredColumns: Array[String],
filters: Array[Filter],
parts: Array[Partition],
options: JDBCOptions): RDD[InternalRow] = {
options: JDBCOptions,
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
val quotedColumns = if (groupByColumns.isEmpty) {
requiredColumns.map(colName => dialect.quoteIdentifier(colName))
} else {
// these are already quoted in JDBCScanBuilder
requiredColumns
}
new JDBCRDD(
sc,
JdbcUtils.createConnectionFactory(options),
pruneSchema(schema, requiredColumns),
outputSchema.getOrElse(pruneSchema(schema, requiredColumns)),
quotedColumns,
filters,
parts,
url,
options)
options,
groupByColumns)
}
}
@ -181,7 +220,8 @@ private[jdbc] class JDBCRDD(
filters: Array[Filter],
partitions: Array[Partition],
url: String,
options: JDBCOptions)
options: JDBCOptions,
groupByColumns: Option[Array[FieldReference]])
extends RDD[InternalRow](sc, Nil) {
/**
@ -221,6 +261,20 @@ private[jdbc] class JDBCRDD(
}
}
/**
* A GROUP BY clause representing pushed-down grouping columns.
*/
private def getGroupByClause: String = {
if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
assert(groupByColumns.get.forall(_.fieldNames.length == 1))
val dialect = JdbcDialects.get(url)
val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
s"GROUP BY ${quotedColumns.mkString(", ")}"
} else {
""
}
}
/**
* Runs the SQL query against the JDBC driver.
*
@ -296,7 +350,8 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
s" $getGroupByClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
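To make the compiled form concrete, a sketch of what `compileAggregates` and `getGroupByClause` contribute to the generated statement. Table and column names and the H2 URL are assumptions; the default dialect quotes identifiers with double quotes, and `FieldReference` again requires compiling under `org.apache.spark.sql`:

```scala
import org.apache.spark.sql.connector.expressions.{CountStar, FieldReference, Max, Min}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:/tmp/testdb")
val compiled = JDBCRDD.compileAggregates(Seq(
  new Max(FieldReference("SALARY").asInstanceOf[FieldReference]),
  new Min(FieldReference("BONUS").asInstanceOf[FieldReference]),
  new CountStar()), dialect)
// compiled: Seq("MAX(\"SALARY\")", "MIN(\"BONUS\")", "COUNT(1)")

// Together with a pushed group-by column, the statement built in compute() is roughly:
//   SELECT "DEPT",MAX("SALARY"),MIN("BONUS"),COUNT(1) FROM "test"."employee"
//   WHERE "DEPT" > 0 GROUP BY "DEPT"
```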

View file

@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
@ -288,6 +289,23 @@ private[sql] case class JDBCRelation(
jdbcOptions).asInstanceOf[RDD[Row]]
}
def buildScan(
requiredColumns: Array[String],
requireSchema: Option[StructType],
filters: Array[Filter],
groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
requiredColumns,
filters,
parts,
jdbcOptions,
requireSchema,
groupByColumns).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)

View file

@ -87,7 +87,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters,
relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@ -98,8 +98,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val dsScan = RowDataSourceScanExec(
output,
output.toStructType,
translated.toSet,
Set.empty,
pushed.toSet,
aggregate,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)

View file

@ -20,9 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@ -70,6 +74,42 @@ object PushDownUtils extends PredicateHelper {
}
}
/**
* Pushes down aggregates to the data source reader
*
* @return pushed aggregation.
*/
def pushAggregates(
scanBuilder: ScanBuilder,
aggregates: Seq[AggregateExpression],
groupBy: Seq[Expression]): Option[Aggregation] = {
def columnAsString(e: Expression): Option[FieldReference] = e match {
case PushableColumnWithoutNestedColumn(name) =>
Some(FieldReference(name).asInstanceOf[FieldReference])
case _ => None
}
scanBuilder match {
case r: SupportsPushDownAggregates =>
val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
val translatedGroupBys = groupBy.map(columnAsString).flatten
if (translatedAggregates.length != aggregates.length ||
translatedGroupBys.length != groupBy.length) {
return None
}
val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
if (r.pushAggregation(agg)) {
Some(agg)
} else {
None
}
case _ => None
}
}
/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*

View file

@ -17,23 +17,36 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.read.{Scan, V1Scan}
import org.apache.spark.sql.connector.expressions.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
object V2ScanRelationPushDown extends Rule[LogicalPlan] {
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
def apply(plan: LogicalPlan): LogicalPlan = {
applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
}
val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
case r: DataSourceV2Relation =>
ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options))
}
private def pushDownFilters(plan: LogicalPlan) = plan.transform {
// update the scan builder with filter push down and return a new plan with filter pushed
case Filter(condition, sHolder: ScanBuilderHolder) =>
val filters = splitConjunctivePredicates(condition)
val normalizedFilters =
DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output)
val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
normalizedFilters.partition(SubqueryExpression.hasSubquery)
@ -41,37 +54,142 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
// `postScanFilters` need to be evaluated after the scan.
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
scanBuilder, normalizedFiltersWithoutSubquery)
sHolder.builder, normalizedFiltersWithoutSubquery)
val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
val normalizedProjects = DataSourceStrategy
.normalizeExprs(project, relation.output)
.asInstanceOf[Seq[NamedExpression]]
val (scan, output) = PushDownUtils.pruneColumns(
scanBuilder, relation, normalizedProjects, postScanFilters)
logInfo(
s"""
|Pushing operators to ${relation.name}
|Pushing operators to ${sHolder.relation.name}
|Pushed Filters: ${pushedFilters.mkString(", ")}
|Post-Scan Filters: ${postScanFilters.mkString(",")}
""".stripMargin)
val filterCondition = postScanFilters.reduceLeftOption(And)
filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
}
def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
// update the scan builder with agg pushdown and return a new plan with agg pushed
case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
child match {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
sHolder.builder match {
case _: SupportsPushDownAggregates =>
val aggregates = resultExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression => agg
}
}
val pushedAggregates = PushDownUtils
.pushAggregates(sHolder.builder, aggregates, groupingExpressions)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else {
// No need to do column pruning because only the aggregate columns are used as
// DataSourceV2ScanRelation output columns. All the other columns are not
// included in the output.
val scan = sHolder.builder.build()
// scalastyle:off
// use the group by columns and aggregate columns as the output columns
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
// Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
// We want to have the following logical plan:
// == Optimized Logical Plan ==
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
// scalastyle:on
val newOutput = scan.readSchema().toAttributes
assert(newOutput.length == groupingExpressions.length + aggregates.length)
val groupAttrs = groupingExpressions.zip(newOutput).map {
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
case (_, b) => b
}
val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
logInfo(
s"""
|Pushing operators to ${sHolder.relation.name}
|Pushed Aggregate Functions:
| ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
|Pushed Group by:
| ${pushedAggregates.get.groupByColumns.mkString(", ")}
|Output: ${output.mkString(", ")}
""".stripMargin)
val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
val plan = Aggregate(
output.take(groupingExpressions.length), resultExpressions, scanRelation)
// scalastyle:off
// Change the optimized logical plan to reflect the pushed down aggregate
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
// The original logical plan is
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// +- RelationV2[c1#9, c2#10] ...
//
// After changing the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
// we have the following
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
//
// We want to change it to
// == Optimized Logical Plan ==
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// scalastyle:on
var i = 0
val aggOutput = output.drop(groupAttrs.length)
plan.transformExpressions {
case agg: AggregateExpression =>
val aggFunction: aggregate.AggregateFunction =
agg.aggregateFunction match {
case max: aggregate.Max => max.copy(child = aggOutput(i))
case min: aggregate.Min => min.copy(child = aggOutput(i))
case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
case _: aggregate.Count => aggregate.Sum(aggOutput(i))
case other => other
}
i += 1
agg.copy(aggregateFunction = aggFunction)
}
}
case _ => aggNode
}
case _ => aggNode
}
}
def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
// column pruning
val normalizedProjects = DataSourceStrategy
.normalizeExprs(project, sHolder.output)
.asInstanceOf[Seq[NamedExpression]]
val (scan, output) = PushDownUtils.pruneColumns(
sHolder.builder, sHolder.relation, normalizedProjects, filters)
logInfo(
s"""
|Output: ${output.mkString(", ")}
""".stripMargin)
val wrappedScan = scan match {
case v1: V1Scan =>
val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
V1ScanWrapper(v1, translated, pushedFilters)
case _ => scan
}
val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation])
val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output)
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
val projectionOverSchema = ProjectionOverSchema(output.toStructType)
val projectionFunc = (expr: Expression) => expr transformDown {
case projectionOverSchema(newExpr) => newExpr
}
val filterCondition = postScanFilters.reduceLeftOption(And)
val filterCondition = filters.reduceLeftOption(And)
val newFilterCondition = filterCondition.map(projectionFunc)
val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
@ -83,16 +201,36 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
} else {
withFilter
}
withProjection
}
private def getWrappedScan(
scan: Scan,
sHolder: ScanBuilderHolder,
aggregation: Option[Aggregation]): Scan = {
scan match {
case v1: V1Scan =>
val pushedFilters = sHolder.builder match {
case f: SupportsPushDownFilters =>
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
V1ScanWrapper(v1, pushedFilters, aggregation)
case _ => scan
}
}
}
case class ScanBuilderHolder(
output: Seq[AttributeReference],
relation: DataSourceV2Relation,
builder: ScanBuilder) extends LeafNode
// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
// the physical v1 scan node.
case class V1ScanWrapper(
v1Scan: V1Scan,
translatedFilters: Seq[sources.Filter],
handledFilters: Seq[sources.Filter]) extends Scan {
handledFilters: Seq[sources.Filter],
pushedAggregate: Option[Aggregation]) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}

View file

@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@ -26,7 +27,9 @@ import org.apache.spark.sql.types.StructType
case class JDBCScan(
relation: JDBCRelation,
prunedSchema: StructType,
pushedFilters: Array[Filter]) extends V1Scan {
pushedFilters: Array[Filter],
pushedAggregateColumn: Array[String] = Array(),
groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@ -36,14 +39,28 @@ case class JDBCScan(
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters)
if (groupByColumns.isEmpty) {
relation.buildScan(
prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
} else {
relation.buildScan(
pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
}
}
}.asInstanceOf[T]
}
override def description(): String = {
val (aggString, groupByString) = if (groupByColumns.nonEmpty) {
val groupByColumnsLength = groupByColumns.get.length
(seqToString(pushedAggregateColumn.drop(groupByColumnsLength)),
seqToString(pushedAggregateColumn.take(groupByColumnsLength)))
} else {
("[]", "[]")
}
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
", PushedFilters: " + seqToString(pushedFilters)
", PushedFilters: " + seqToString(pushedFilters) +
", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString
}
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")

View file

@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{LongType, StructField, StructType}
case class JDBCScanBuilder(
session: SparkSession,
schema: StructType,
jdbcOptions: JDBCOptions)
extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
with SupportsPushDownAggregates {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@ -49,6 +51,58 @@ case class JDBCScanBuilder(
override def pushedFilters(): Array[Filter] = pushedFilter
private var pushedAggregations = Option.empty[Aggregation]
private var pushedAggregateColumn: Array[String] = Array()
private def getStructFieldForCol(col: FieldReference): StructField =
schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
override def pushAggregation(aggregation: Aggregation): Boolean = {
if (!jdbcOptions.pushDownAggregate) return false
val dialect = JdbcDialects.get(jdbcOptions.url)
val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
var outputSchema = new StructType()
aggregation.groupByColumns.foreach { col =>
val structField = getStructFieldForCol(col)
outputSchema = outputSchema.add(structField)
pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name)
}
// The column names here are already quoted and can be used to build sql string directly.
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
// SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
// GROUP BY "DEPT", "NAME"
pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
aggregation.aggregateExpressions.foreach {
case max: Max =>
val structField = getStructFieldForCol(max.column)
outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")"))
case min: Min =>
val structField = getStructFieldForCol(min.column)
outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
case count: Count =>
val distinct = if (count.isDistinct) "DISTINCT " else ""
val structField = getStructFieldForCol(count.column)
outputSchema =
outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
case _: CountStar =>
outputSchema = outputSchema.add(StructField("count(*)", LongType))
case sum: Sum =>
val distinct = if (sum.isDistinct) "DISTINCT " else ""
val structField = getStructFieldForCol(sum.column)
outputSchema =
outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
case _ => return false
}
this.pushedAggregations = Some(aggregation)
prunedSchema = outputSchema
true
}
override def pruneColumns(requiredSchema: StructType): Unit = {
// JDBC doesn't support nested column pruning.
// TODO (SPARK-32593): JDBC support nested column and nested column pruning.
@ -65,6 +119,20 @@ case class JDBCScanBuilder(
val resolver = session.sessionState.conf.resolver
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter)
// prunedSchema is either pruned in pushAggregation (if aggregates are pushed down) or in
// pruneColumns (regular column pruning). These two cases are mutually exclusive.
// For the aggregate push down case, we want to pass down the already-quoted column list,
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of taking column names from
// prunedSchema and quoting them (which would yield "MAX(SALARY)", "MIN(BONUS)" and
// can't be used in the SQL string).
val groupByColumns = if (pushedAggregations.nonEmpty) {
Some(pushedAggregations.get.groupByColumns)
} else {
Option.empty[Array[FieldReference]]
}
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter,
pushedAggregateColumn, groupByColumns)
}
}

View file

@ -21,16 +21,16 @@ import java.sql.{Connection, DriverManager}
import java.util.Properties
import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.{lit, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
class JDBCV2Suite extends QueryTest with SharedSparkSession {
class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
import testImplicits._
val tempDir = Utils.createTempDir()
@ -41,6 +41,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.h2.url", url)
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
.set("spark.sql.catalog.h2.pushDownAggregate", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
@ -64,6 +65,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
conn.prepareStatement(
"CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
" bonus DOUBLE)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
.executeUpdate()
}
}
@ -84,6 +98,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Row("mary", 2))
}
@ -145,7 +167,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
Seq(Row("test", "people", false), Row("test", "empty_table", false)))
Seq(Row("test", "people", false), Row("test", "empty_table", false),
Row("test", "employee", false)))
}
test("SQL API: create table as select") {
@ -214,4 +237,232 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4))
}
}
test("scan with aggregate push-down: MAX MIN with filter and group by") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
}
test("scan with aggregate push-down: MAX MIN with filter without group by") {
val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Max(ID), Min(ID)], " +
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
"PushedGroupby: []"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(2, 1)))
}
test("scan with aggregate push-down: aggregate + number") {
val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Max(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(12001)))
}
test("scan with aggregate push-down: COUNT(*)") {
val df = sql("select COUNT(*) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [CountStar()]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(col)") {
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Count(DEPT,false)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(DISTINCT col)") {
val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Count(DEPT,true)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(3)))
}
test("scan with aggregate push-down: SUM without filer and group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(53000)))
}
test("scan with aggregate push-down: DISTINCT SUM without filer and group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(31000)))
}
test("scan with aggregate push-down: SUM with group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: DISTINCT SUM with group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: with multiple group by columns") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT, NAME")
val filters11 = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters11.isEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT, NAME]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
Row(10000, 1000), Row(12000, 1200)))
}
test("scan with aggregate push-down: with having clause") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT having MIN(BONUS) > 1000")
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f // filter over aggregate is not pushed down
}
assert(filters.nonEmpty)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
}
test("scan with aggregate push-down: alias over aggregate") {
val df = sql("select * from h2.test.employee")
.groupBy($"DEPT")
.min("SALARY").as("total")
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Min(SALARY)], " +
"PushedFilters: [], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
}
test("scan with aggregate push-down: order by alias over aggregate") {
val df = spark.table("h2.test.employee")
val query = df.select($"DEPT", $"SALARY")
.filter($"DEPT" > 0)
.groupBy($"DEPT")
.agg(sum($"SALARY").as("total"))
.filter($"total" > 1000)
.orderBy($"total")
val filters = query.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
assert(filters.nonEmpty) // filter over aggregate not pushed down
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
"PushedGroupby: [DEPT]"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
}
test("scan with aggregate push-down: udf over aggregate") {
val df = spark.table("h2.test.employee")
val decrease = udf { (x: Double, y: Double) => x - y }
val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
query.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)"
checkKeywordsExistsInExplain(query, expected_plan_fragment)
}
checkAnswer(query, Seq(Row(47100.0)))
}
test("scan with aggregate push-down: aggregate over alias") {
val cols = Seq("a", "b", "c", "d")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
val df2 = df1.groupBy().sum("c")
df2.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: []" // aggregate over alias not push down
checkKeywordsExistsInExplain(df2, expected_plan_fragment)
}
checkAnswer(df2, Seq(Row(53000.00)))
}
}