[SPARK-34952][SQL] DSv2 Aggregate push down APIs
### What changes were proposed in this pull request?
Add interfaces and APIs to push down Aggregates to V2 Data Source
### Why are the changes needed?
improve performance
### Does this PR introduce _any_ user-facing change?
SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED was added. If this is set to true, Aggregates are pushed down to Data Source.
### How was this patch tested?
New tests were added to test aggregates push down in https://github.com/apache/spark/pull/32049. The original PR is split into two PRs. This PR doesn't contain new tests.
Closes #33352 from huaxingao/aggPushDownInterface.
Authored-by: Huaxin Gao <huaxin_gao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit c561ee6865
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
1e17a5bc19
commit
b1f522cf97
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* Base class of the Aggregate Functions.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public interface AggregateFunc extends Expression, Serializable {
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* Aggregation in SQL statement.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class Aggregation implements Serializable {
|
||||
private AggregateFunc[] aggregateExpressions;
|
||||
private FieldReference[] groupByColumns;
|
||||
|
||||
public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
|
||||
this.aggregateExpressions = aggregateExpressions;
|
||||
this.groupByColumns = groupByColumns;
|
||||
}
|
||||
|
||||
public AggregateFunc[] aggregateExpressions() {
|
||||
return aggregateExpressions;
|
||||
}
|
||||
|
||||
public FieldReference[] groupByColumns() {
|
||||
return groupByColumns;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
/**
|
||||
* An aggregate function that returns the number of the specific row in a group.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class Count implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private boolean isDistinct;
|
||||
|
||||
public Count(FieldReference column, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
public boolean isDinstinct() {
|
||||
return isDistinct;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
/**
|
||||
* An aggregate function that returns the number of rows in a group.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class CountStar implements AggregateFunc {
|
||||
|
||||
public CountStar() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CountStar()";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
/**
|
||||
* An aggregate function that returns the maximum value in a group.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class Max implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
|
||||
public Max(FieldReference column) {
|
||||
this.column = column;
|
||||
}
|
||||
|
||||
public FieldReference column() { return column; }
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Max(" + column.describe() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
|
||||
/**
|
||||
* An aggregate function that returns the minimum value in a group.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class Min implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
|
||||
public Min(FieldReference column) {
|
||||
this.column = column;
|
||||
}
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Min(" + column.describe() + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.expressions;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
import org.apache.spark.sql.types.DataType;
|
||||
|
||||
/**
|
||||
* An aggregate function that returns the summation of all the values in a group.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public final class Sum implements AggregateFunc {
|
||||
private FieldReference column;
|
||||
private DataType dataType;
|
||||
private boolean isDistinct;
|
||||
|
||||
public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
|
||||
this.column = column;
|
||||
this.dataType = dataType;
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
|
||||
public FieldReference column() {
|
||||
return column;
|
||||
}
|
||||
public DataType dataType() {
|
||||
return dataType;
|
||||
}
|
||||
public boolean isDinstinct() {
|
||||
return isDistinct;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String describe() { return this.toString(); }
|
||||
}
|
|
@ -22,7 +22,8 @@ import org.apache.spark.annotation.Evolving;
|
|||
/**
|
||||
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
|
||||
* interfaces to do operator pushdown, and keep the operator pushdown result in the returned
|
||||
* {@link Scan}.
|
||||
* {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
|
||||
* aggregates or applies column pruning.
|
||||
*
|
||||
* @since 3.0.0
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.connector.read;
|
||||
|
||||
import org.apache.spark.annotation.Evolving;
|
||||
import org.apache.spark.sql.connector.expressions.Aggregation;
|
||||
|
||||
/**
|
||||
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
|
||||
* push down aggregates. Spark assumes that the data source can't fully complete the
|
||||
* grouping work, and will group the data source output again. For queries like
|
||||
* "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
|
||||
* to the data source, the data source can still output data with duplicated keys, which is OK
|
||||
* as Spark will do GROUP BY key again. The final query plan can be something like this:
|
||||
* {{{
|
||||
* Aggregate [key#1], [min(min(value)#2) AS m#3]
|
||||
* +- RelationV2[key#1, min(value)#2]
|
||||
* }}}
|
||||
*
|
||||
* <p>
|
||||
* Similarly, if there is no grouping expression, the data source can still output more than one
|
||||
* rows.
|
||||
*
|
||||
* <p>
|
||||
* When pushing down operators, Spark pushes down filters to the data source first, then push down
|
||||
* aggregates or apply column pruning. Depends on data source implementation, aggregates may or
|
||||
* may not be able to be pushed down with filters. If pushed filters still need to be evaluated
|
||||
* after scanning, aggregates can't be pushed down.
|
||||
*
|
||||
* @since 3.2.0
|
||||
*/
|
||||
@Evolving
|
||||
public interface SupportsPushDownAggregates extends ScanBuilder {
|
||||
|
||||
/**
|
||||
* Pushes down Aggregation to datasource. The order of the datasource scan output columns should
|
||||
* be: grouping columns, aggregate columns (in the same order as the aggregate functions in
|
||||
* the given Aggregation).
|
||||
*/
|
||||
boolean pushAggregation(Aggregation aggregation);
|
||||
}
|
|
@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
|
||||
import org.apache.spark.sql.catalyst.util.truncatedString
|
||||
import org.apache.spark.sql.connector.expressions.Aggregation
|
||||
import org.apache.spark.sql.execution.datasources._
|
||||
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
|
||||
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
|
||||
|
@ -102,6 +103,7 @@ case class RowDataSourceScanExec(
|
|||
requiredSchema: StructType,
|
||||
filters: Set[Filter],
|
||||
handledFilters: Set[Filter],
|
||||
aggregation: Option[Aggregation],
|
||||
rdd: RDD[InternalRow],
|
||||
@transient relation: BaseRelation,
|
||||
tableIdentifier: Option[TableIdentifier])
|
||||
|
@ -129,12 +131,29 @@ case class RowDataSourceScanExec(
|
|||
override def inputRDD: RDD[InternalRow] = rdd
|
||||
|
||||
override val metadata: Map[String, String] = {
|
||||
val markedFilters = for (filter <- filters) yield {
|
||||
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
|
||||
|
||||
def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
|
||||
|
||||
val (aggString, groupByString) = if (aggregation.nonEmpty) {
|
||||
(seqToString(aggregation.get.aggregateExpressions),
|
||||
seqToString(aggregation.get.groupByColumns))
|
||||
} else {
|
||||
("[]", "[]")
|
||||
}
|
||||
|
||||
val markedFilters = if (filters.nonEmpty) {
|
||||
for (filter <- filters) yield {
|
||||
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
|
||||
}
|
||||
} else {
|
||||
handledFilters
|
||||
}
|
||||
|
||||
Map(
|
||||
"ReadSchema" -> requiredSchema.catalogString,
|
||||
"PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
|
||||
"PushedFilters" -> seqToString(markedFilters.toSeq),
|
||||
"PushedAggregates" -> aggString,
|
||||
"PushedGroupby" -> groupByString)
|
||||
}
|
||||
|
||||
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
|
||||
|
|
|
@ -33,12 +33,14 @@ import org.apache.spark.sql.catalyst.catalog._
|
|||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.planning.ScanOperation
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
|
||||
import org.apache.spark.sql.connector.catalog.SupportsRead
|
||||
import org.apache.spark.sql.connector.catalog.TableCapability._
|
||||
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
|
||||
import org.apache.spark.sql.errors.QueryCompilationErrors
|
||||
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
|
||||
import org.apache.spark.sql.execution.command._
|
||||
|
@ -332,6 +334,7 @@ object DataSourceStrategy
|
|||
l.output.toStructType,
|
||||
Set.empty,
|
||||
Set.empty,
|
||||
None,
|
||||
toCatalystRDD(l, baseRelation.buildScan()),
|
||||
baseRelation,
|
||||
None) :: Nil
|
||||
|
@ -405,6 +408,7 @@ object DataSourceStrategy
|
|||
requestedColumns.toStructType,
|
||||
pushedFilters.toSet,
|
||||
handledFilters,
|
||||
None,
|
||||
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
|
||||
relation.relation,
|
||||
relation.catalogTable.map(_.identifier))
|
||||
|
@ -427,6 +431,7 @@ object DataSourceStrategy
|
|||
requestedColumns.toStructType,
|
||||
pushedFilters.toSet,
|
||||
handledFilters,
|
||||
None,
|
||||
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
|
||||
relation.relation,
|
||||
relation.catalogTable.map(_.identifier))
|
||||
|
@ -692,6 +697,32 @@ object DataSourceStrategy
|
|||
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
|
||||
}
|
||||
|
||||
protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
|
||||
if (aggregates.filter.isEmpty) {
|
||||
aggregates.aggregateFunction match {
|
||||
case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
|
||||
Some(new Min(FieldReference(name).asInstanceOf[FieldReference]))
|
||||
case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
|
||||
Some(new Max(FieldReference(name).asInstanceOf[FieldReference]))
|
||||
case count: aggregate.Count if count.children.length == 1 =>
|
||||
count.children.head match {
|
||||
// SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
|
||||
case Literal(_, _) => Some(new CountStar())
|
||||
case PushableColumnWithoutNestedColumn(name) =>
|
||||
Some(new Count(FieldReference(name).asInstanceOf[FieldReference],
|
||||
aggregates.isDistinct))
|
||||
case _ => None
|
||||
}
|
||||
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
|
||||
Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
|
||||
sum.dataType, aggregates.isDistinct))
|
||||
case _ => None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
|
||||
*/
|
||||
|
|
|
@ -188,6 +188,9 @@ class JDBCOptions(
|
|||
// An option to allow/disallow pushing down predicate into JDBC data source
|
||||
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE, "true").toBoolean
|
||||
|
||||
// An option to allow/disallow pushing down aggregate into JDBC data source
|
||||
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
|
||||
|
||||
// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
|
||||
// by --files option of spark-submit or manually
|
||||
val keytab = {
|
||||
|
@ -259,6 +262,7 @@ object JDBCOptions {
|
|||
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
|
||||
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
|
||||
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
|
||||
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
|
||||
val JDBC_KEYTAB = newOption("keytab")
|
||||
val JDBC_PRINCIPAL = newOption("principal")
|
||||
val JDBC_TABLE_COMMENT = newOption("tableComment")
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
|
||||
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
|
||||
import org.apache.spark.sql.sources._
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -133,6 +134,34 @@ object JDBCRDD extends Logging {
|
|||
})
|
||||
}
|
||||
|
||||
def compileAggregates(
|
||||
aggregates: Seq[AggregateFunc],
|
||||
dialect: JdbcDialect): Seq[String] = {
|
||||
def quote(colName: String): String = dialect.quoteIdentifier(colName)
|
||||
|
||||
aggregates.map {
|
||||
case min: Min =>
|
||||
assert(min.column.fieldNames.length == 1)
|
||||
s"MIN(${quote(min.column.fieldNames.head)})"
|
||||
case max: Max =>
|
||||
assert(max.column.fieldNames.length == 1)
|
||||
s"MAX(${quote(max.column.fieldNames.head)})"
|
||||
case count: Count =>
|
||||
assert(count.column.fieldNames.length == 1)
|
||||
val distinct = if (count.isDinstinct) "DISTINCT" else ""
|
||||
val column = quote(count.column.fieldNames.head)
|
||||
s"COUNT($distinct $column)"
|
||||
case sum: Sum =>
|
||||
assert(sum.column.fieldNames.length == 1)
|
||||
val distinct = if (sum.isDinstinct) "DISTINCT" else ""
|
||||
val column = quote(sum.column.fieldNames.head)
|
||||
s"SUM($distinct $column)"
|
||||
case _: CountStar =>
|
||||
s"COUNT(1)"
|
||||
case _ => ""
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build and return JDBCRDD from the given information.
|
||||
*
|
||||
|
@ -143,6 +172,8 @@ object JDBCRDD extends Logging {
|
|||
* @param parts - An array of JDBCPartitions specifying partition ids and
|
||||
* per-partition WHERE clauses.
|
||||
* @param options - JDBC options that contains url, table and other information.
|
||||
* @param requiredSchema - The schema of the columns to SELECT.
|
||||
* @param aggregation - The pushed down aggregation
|
||||
*
|
||||
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
|
||||
*/
|
||||
|
@ -152,19 +183,27 @@ object JDBCRDD extends Logging {
|
|||
requiredColumns: Array[String],
|
||||
filters: Array[Filter],
|
||||
parts: Array[Partition],
|
||||
options: JDBCOptions): RDD[InternalRow] = {
|
||||
options: JDBCOptions,
|
||||
outputSchema: Option[StructType] = None,
|
||||
groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
|
||||
val url = options.url
|
||||
val dialect = JdbcDialects.get(url)
|
||||
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
|
||||
val quotedColumns = if (groupByColumns.isEmpty) {
|
||||
requiredColumns.map(colName => dialect.quoteIdentifier(colName))
|
||||
} else {
|
||||
// these are already quoted in JDBCScanBuilder
|
||||
requiredColumns
|
||||
}
|
||||
new JDBCRDD(
|
||||
sc,
|
||||
JdbcUtils.createConnectionFactory(options),
|
||||
pruneSchema(schema, requiredColumns),
|
||||
outputSchema.getOrElse(pruneSchema(schema, requiredColumns)),
|
||||
quotedColumns,
|
||||
filters,
|
||||
parts,
|
||||
url,
|
||||
options)
|
||||
options,
|
||||
groupByColumns)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,7 +220,8 @@ private[jdbc] class JDBCRDD(
|
|||
filters: Array[Filter],
|
||||
partitions: Array[Partition],
|
||||
url: String,
|
||||
options: JDBCOptions)
|
||||
options: JDBCOptions,
|
||||
groupByColumns: Option[Array[FieldReference]])
|
||||
extends RDD[InternalRow](sc, Nil) {
|
||||
|
||||
/**
|
||||
|
@ -221,6 +261,20 @@ private[jdbc] class JDBCRDD(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A GROUP BY clause representing pushed-down grouping columns.
|
||||
*/
|
||||
private def getGroupByClause: String = {
|
||||
if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
|
||||
assert(groupByColumns.get.forall(_.fieldNames.length == 1))
|
||||
val dialect = JdbcDialects.get(url)
|
||||
val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
|
||||
s"GROUP BY ${quotedColumns.mkString(", ")}"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the SQL query against the JDBC driver.
|
||||
*
|
||||
|
@ -296,7 +350,8 @@ private[jdbc] class JDBCRDD(
|
|||
|
||||
val myWhereClause = getWhereClause(part)
|
||||
|
||||
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
|
||||
val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
|
||||
s" $getGroupByClause"
|
||||
stmt = conn.prepareStatement(sqlText,
|
||||
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
|
||||
stmt.setFetchSize(options.fetchSize)
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
|
|||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
|
||||
import org.apache.spark.sql.connector.expressions.FieldReference
|
||||
import org.apache.spark.sql.errors.QueryCompilationErrors
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.jdbc.JdbcDialects
|
||||
|
@ -288,6 +289,23 @@ private[sql] case class JDBCRelation(
|
|||
jdbcOptions).asInstanceOf[RDD[Row]]
|
||||
}
|
||||
|
||||
def buildScan(
|
||||
requiredColumns: Array[String],
|
||||
requireSchema: Option[StructType],
|
||||
filters: Array[Filter],
|
||||
groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
|
||||
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
|
||||
JDBCRDD.scanTable(
|
||||
sparkSession.sparkContext,
|
||||
schema,
|
||||
requiredColumns,
|
||||
filters,
|
||||
parts,
|
||||
jdbcOptions,
|
||||
requireSchema,
|
||||
groupByColumns).asInstanceOf[RDD[Row]]
|
||||
}
|
||||
|
||||
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
|
||||
data.write
|
||||
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
|
||||
|
|
|
@ -87,7 +87,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
|
|||
|
||||
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
|
||||
case PhysicalOperation(project, filters,
|
||||
relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
|
||||
DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
|
||||
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
|
||||
if (v1Relation.schema != scan.readSchema()) {
|
||||
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
|
||||
|
@ -98,8 +98,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
|
|||
val dsScan = RowDataSourceScanExec(
|
||||
output,
|
||||
output.toStructType,
|
||||
translated.toSet,
|
||||
Set.empty,
|
||||
pushed.toSet,
|
||||
aggregate,
|
||||
unsafeRowRDD,
|
||||
v1Relation,
|
||||
tableIdentifier = None)
|
||||
|
|
|
@ -20,9 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
|
|||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
|
||||
import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference}
|
||||
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
|
||||
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
|
||||
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
|
||||
import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.sources
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
@ -70,6 +74,42 @@ object PushDownUtils extends PredicateHelper {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pushes down aggregates to the data source reader
|
||||
*
|
||||
* @return pushed aggregation.
|
||||
*/
|
||||
def pushAggregates(
|
||||
scanBuilder: ScanBuilder,
|
||||
aggregates: Seq[AggregateExpression],
|
||||
groupBy: Seq[Expression]): Option[Aggregation] = {
|
||||
|
||||
def columnAsString(e: Expression): Option[FieldReference] = e match {
|
||||
case PushableColumnWithoutNestedColumn(name) =>
|
||||
Some(FieldReference(name).asInstanceOf[FieldReference])
|
||||
case _ => None
|
||||
}
|
||||
|
||||
scanBuilder match {
|
||||
case r: SupportsPushDownAggregates =>
|
||||
val translatedAggregates = aggregates.map(DataSourceStrategy.translateAggregate).flatten
|
||||
val translatedGroupBys = groupBy.map(columnAsString).flatten
|
||||
|
||||
if (translatedAggregates.length != aggregates.length ||
|
||||
translatedGroupBys.length != groupBy.length) {
|
||||
return None
|
||||
}
|
||||
|
||||
val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
|
||||
if (r.pushAggregation(agg)) {
|
||||
Some(agg)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
|
||||
*
|
||||
|
|
|
@ -17,23 +17,36 @@
|
|||
|
||||
package org.apache.spark.sql.execution.datasources.v2
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.planning.ScanOperation
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.connector.read.{Scan, V1Scan}
|
||||
import org.apache.spark.sql.connector.expressions.Aggregation
|
||||
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
|
||||
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
|
||||
import org.apache.spark.sql.sources
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
object V2ScanRelationPushDown extends Rule[LogicalPlan] {
|
||||
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
||||
import DataSourceV2Implicits._
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
|
||||
case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
|
||||
val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
|
||||
def apply(plan: LogicalPlan): LogicalPlan = {
|
||||
applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
|
||||
}
|
||||
|
||||
val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
|
||||
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
|
||||
case r: DataSourceV2Relation =>
|
||||
ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options))
|
||||
}
|
||||
|
||||
private def pushDownFilters(plan: LogicalPlan) = plan.transform {
|
||||
// update the scan builder with filter push down and return a new plan with filter pushed
|
||||
case Filter(condition, sHolder: ScanBuilderHolder) =>
|
||||
val filters = splitConjunctivePredicates(condition)
|
||||
val normalizedFilters =
|
||||
DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output)
|
||||
val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
|
||||
normalizedFilters.partition(SubqueryExpression.hasSubquery)
|
||||
|
||||
|
@ -41,37 +54,142 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
|
|||
// `postScanFilters` need to be evaluated after the scan.
|
||||
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
|
||||
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
|
||||
scanBuilder, normalizedFiltersWithoutSubquery)
|
||||
sHolder.builder, normalizedFiltersWithoutSubquery)
|
||||
val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
|
||||
|
||||
val normalizedProjects = DataSourceStrategy
|
||||
.normalizeExprs(project, relation.output)
|
||||
.asInstanceOf[Seq[NamedExpression]]
|
||||
val (scan, output) = PushDownUtils.pruneColumns(
|
||||
scanBuilder, relation, normalizedProjects, postScanFilters)
|
||||
logInfo(
|
||||
s"""
|
||||
|Pushing operators to ${relation.name}
|
||||
|Pushing operators to ${sHolder.relation.name}
|
||||
|Pushed Filters: ${pushedFilters.mkString(", ")}
|
||||
|Post-Scan Filters: ${postScanFilters.mkString(",")}
|
||||
""".stripMargin)
|
||||
|
||||
val filterCondition = postScanFilters.reduceLeftOption(And)
|
||||
filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
|
||||
}
|
||||
|
||||
def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
|
||||
// update the scan builder with agg pushdown and return a new plan with agg pushed
|
||||
case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
|
||||
child match {
|
||||
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
|
||||
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
|
||||
sHolder.builder match {
|
||||
case _: SupportsPushDownAggregates =>
|
||||
val aggregates = resultExpressions.flatMap { expr =>
|
||||
expr.collect {
|
||||
case agg: AggregateExpression => agg
|
||||
}
|
||||
}
|
||||
val pushedAggregates = PushDownUtils
|
||||
.pushAggregates(sHolder.builder, aggregates, groupingExpressions)
|
||||
if (pushedAggregates.isEmpty) {
|
||||
aggNode // return original plan node
|
||||
} else {
|
||||
// No need to do column pruning because only the aggregate columns are used as
|
||||
// DataSourceV2ScanRelation output columns. All the other columns are not
|
||||
// included in the output.
|
||||
val scan = sHolder.builder.build()
|
||||
|
||||
// scalastyle:off
|
||||
// use the group by columns and aggregate columns as the output columns
|
||||
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
|
||||
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
|
||||
// Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
|
||||
// We want to have the following logical plan:
|
||||
// == Optimized Logical Plan ==
|
||||
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
|
||||
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
|
||||
// scalastyle:on
|
||||
val newOutput = scan.readSchema().toAttributes
|
||||
assert(newOutput.length == groupingExpressions.length + aggregates.length)
|
||||
val groupAttrs = groupingExpressions.zip(newOutput).map {
|
||||
case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
|
||||
case (_, b) => b
|
||||
}
|
||||
val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
|
||||
|
||||
logInfo(
|
||||
s"""
|
||||
|Pushing operators to ${sHolder.relation.name}
|
||||
|Pushed Aggregate Functions:
|
||||
| ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
|
||||
|Pushed Group by:
|
||||
| ${pushedAggregates.get.groupByColumns.mkString(", ")}
|
||||
|Output: ${output.mkString(", ")}
|
||||
""".stripMargin)
|
||||
|
||||
val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
|
||||
|
||||
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
|
||||
|
||||
val plan = Aggregate(
|
||||
output.take(groupingExpressions.length), resultExpressions, scanRelation)
|
||||
|
||||
// scalastyle:off
|
||||
// Change the optimized logical plan to reflect the pushed down aggregate
|
||||
// e.g. TABLE t (c1 INT, c2 INT, c3 INT)
|
||||
// SELECT min(c1), max(c1) FROM t GROUP BY c2;
|
||||
// The original logical plan is
|
||||
// Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
|
||||
// +- RelationV2[c1#9, c2#10] ...
|
||||
//
|
||||
// After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
|
||||
// we have the following
|
||||
// !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
|
||||
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
|
||||
//
|
||||
// We want to change it to
|
||||
// == Optimized Logical Plan ==
|
||||
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
|
||||
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
|
||||
// scalastyle:on
|
||||
var i = 0
|
||||
val aggOutput = output.drop(groupAttrs.length)
|
||||
plan.transformExpressions {
|
||||
case agg: AggregateExpression =>
|
||||
val aggFunction: aggregate.AggregateFunction =
|
||||
agg.aggregateFunction match {
|
||||
case max: aggregate.Max => max.copy(child = aggOutput(i))
|
||||
case min: aggregate.Min => min.copy(child = aggOutput(i))
|
||||
case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
|
||||
case _: aggregate.Count => aggregate.Sum(aggOutput(i))
|
||||
case other => other
|
||||
}
|
||||
i += 1
|
||||
agg.copy(aggregateFunction = aggFunction)
|
||||
}
|
||||
}
|
||||
case _ => aggNode
|
||||
}
|
||||
case _ => aggNode
|
||||
}
|
||||
}
|
||||
|
||||
def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
|
||||
case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
|
||||
// column pruning
|
||||
val normalizedProjects = DataSourceStrategy
|
||||
.normalizeExprs(project, sHolder.output)
|
||||
.asInstanceOf[Seq[NamedExpression]]
|
||||
val (scan, output) = PushDownUtils.pruneColumns(
|
||||
sHolder.builder, sHolder.relation, normalizedProjects, filters)
|
||||
|
||||
logInfo(
|
||||
s"""
|
||||
|Output: ${output.mkString(", ")}
|
||||
""".stripMargin)
|
||||
|
||||
val wrappedScan = scan match {
|
||||
case v1: V1Scan =>
|
||||
val translated = filters.flatMap(DataSourceStrategy.translateFilter(_, true))
|
||||
V1ScanWrapper(v1, translated, pushedFilters)
|
||||
case _ => scan
|
||||
}
|
||||
val wrappedScan = getWrappedScan(scan, sHolder, Option.empty[Aggregation])
|
||||
|
||||
val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan, output)
|
||||
val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
|
||||
|
||||
val projectionOverSchema = ProjectionOverSchema(output.toStructType)
|
||||
val projectionFunc = (expr: Expression) => expr transformDown {
|
||||
case projectionOverSchema(newExpr) => newExpr
|
||||
}
|
||||
|
||||
val filterCondition = postScanFilters.reduceLeftOption(And)
|
||||
val filterCondition = filters.reduceLeftOption(And)
|
||||
val newFilterCondition = filterCondition.map(projectionFunc)
|
||||
val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
|
||||
|
||||
|
@ -83,16 +201,36 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
|
|||
} else {
|
||||
withFilter
|
||||
}
|
||||
|
||||
withProjection
|
||||
}
|
||||
|
||||
private def getWrappedScan(
|
||||
scan: Scan,
|
||||
sHolder: ScanBuilderHolder,
|
||||
aggregation: Option[Aggregation]): Scan = {
|
||||
scan match {
|
||||
case v1: V1Scan =>
|
||||
val pushedFilters = sHolder.builder match {
|
||||
case f: SupportsPushDownFilters =>
|
||||
f.pushedFilters()
|
||||
case _ => Array.empty[sources.Filter]
|
||||
}
|
||||
V1ScanWrapper(v1, pushedFilters, aggregation)
|
||||
case _ => scan
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class ScanBuilderHolder(
|
||||
output: Seq[AttributeReference],
|
||||
relation: DataSourceV2Relation,
|
||||
builder: ScanBuilder) extends LeafNode
|
||||
|
||||
// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
|
||||
// the physical v1 scan node.
|
||||
case class V1ScanWrapper(
|
||||
v1Scan: V1Scan,
|
||||
translatedFilters: Seq[sources.Filter],
|
||||
handledFilters: Seq[sources.Filter]) extends Scan {
|
||||
handledFilters: Seq[sources.Filter],
|
||||
pushedAggregate: Option[Aggregation]) extends Scan {
|
||||
override def readSchema(): StructType = v1Scan.readSchema()
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
import org.apache.spark.sql.connector.expressions.FieldReference
|
||||
import org.apache.spark.sql.connector.read.V1Scan
|
||||
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
|
||||
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
|
||||
|
@ -26,7 +27,9 @@ import org.apache.spark.sql.types.StructType
|
|||
case class JDBCScan(
|
||||
relation: JDBCRelation,
|
||||
prunedSchema: StructType,
|
||||
pushedFilters: Array[Filter]) extends V1Scan {
|
||||
pushedFilters: Array[Filter],
|
||||
pushedAggregateColumn: Array[String] = Array(),
|
||||
groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
|
||||
|
||||
override def readSchema(): StructType = prunedSchema
|
||||
|
||||
|
@ -36,14 +39,28 @@ case class JDBCScan(
|
|||
override def schema: StructType = prunedSchema
|
||||
override def needConversion: Boolean = relation.needConversion
|
||||
override def buildScan(): RDD[Row] = {
|
||||
relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters)
|
||||
if (groupByColumns.isEmpty) {
|
||||
relation.buildScan(
|
||||
prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
|
||||
} else {
|
||||
relation.buildScan(
|
||||
pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
|
||||
}
|
||||
}
|
||||
}.asInstanceOf[T]
|
||||
}
|
||||
|
||||
override def description(): String = {
|
||||
val (aggString, groupByString) = if (groupByColumns.nonEmpty) {
|
||||
val groupByColumnsLength = groupByColumns.get.length
|
||||
(seqToString(pushedAggregateColumn.drop(groupByColumnsLength)),
|
||||
seqToString(pushedAggregateColumn.take(groupByColumnsLength)))
|
||||
} else {
|
||||
("[]", "[]")
|
||||
}
|
||||
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
|
||||
", PushedFilters: " + seqToString(pushedFilters)
|
||||
", PushedFilters: " + seqToString(pushedFilters) +
|
||||
", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString
|
||||
}
|
||||
|
||||
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
|
||||
|
|
|
@ -17,18 +17,20 @@
|
|||
package org.apache.spark.sql.execution.datasources.v2.jdbc
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
|
||||
import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, FieldReference, Max, Min, Sum}
|
||||
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
|
||||
import org.apache.spark.sql.execution.datasources.PartitioningUtils
|
||||
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
|
||||
import org.apache.spark.sql.jdbc.JdbcDialects
|
||||
import org.apache.spark.sql.sources.Filter
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.sql.types.{LongType, StructField, StructType}
|
||||
|
||||
case class JDBCScanBuilder(
|
||||
session: SparkSession,
|
||||
schema: StructType,
|
||||
jdbcOptions: JDBCOptions)
|
||||
extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
|
||||
extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
|
||||
with SupportsPushDownAggregates{
|
||||
|
||||
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
|
||||
|
||||
|
@ -49,6 +51,58 @@ case class JDBCScanBuilder(
|
|||
|
||||
override def pushedFilters(): Array[Filter] = pushedFilter
|
||||
|
||||
private var pushedAggregations = Option.empty[Aggregation]
|
||||
|
||||
private var pushedAggregateColumn: Array[String] = Array()
|
||||
|
||||
private def getStructFieldForCol(col: FieldReference): StructField =
|
||||
schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
|
||||
|
||||
override def pushAggregation(aggregation: Aggregation): Boolean = {
|
||||
if (!jdbcOptions.pushDownAggregate) return false
|
||||
|
||||
val dialect = JdbcDialects.get(jdbcOptions.url)
|
||||
val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
|
||||
|
||||
var outputSchema = new StructType()
|
||||
aggregation.groupByColumns.foreach { col =>
|
||||
val structField = getStructFieldForCol(col)
|
||||
outputSchema = outputSchema.add(structField)
|
||||
pushedAggregateColumn = pushedAggregateColumn :+ dialect.quoteIdentifier(structField.name)
|
||||
}
|
||||
|
||||
// The column names here are already quoted and can be used to build sql string directly.
|
||||
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
|
||||
// SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
|
||||
// GROUP BY "DEPT", "NAME"
|
||||
pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
|
||||
|
||||
aggregation.aggregateExpressions.foreach {
|
||||
case max: Max =>
|
||||
val structField = getStructFieldForCol(max.column)
|
||||
outputSchema = outputSchema.add(structField.copy("max(" + structField.name + ")"))
|
||||
case min: Min =>
|
||||
val structField = getStructFieldForCol(min.column)
|
||||
outputSchema = outputSchema.add(structField.copy("min(" + structField.name + ")"))
|
||||
case count: Count =>
|
||||
val distinct = if (count.isDinstinct) "DISTINCT " else ""
|
||||
val structField = getStructFieldForCol(count.column)
|
||||
outputSchema =
|
||||
outputSchema.add(StructField(s"count($distinct" + structField.name + ")", LongType))
|
||||
case _: CountStar =>
|
||||
outputSchema = outputSchema.add(StructField("count(*)", LongType))
|
||||
case sum: Sum =>
|
||||
val distinct = if (sum.isDinstinct) "DISTINCT " else ""
|
||||
val structField = getStructFieldForCol(sum.column)
|
||||
outputSchema =
|
||||
outputSchema.add(StructField(s"sum($distinct" + structField.name + ")", sum.dataType))
|
||||
case _ => return false
|
||||
}
|
||||
this.pushedAggregations = Some(aggregation)
|
||||
prunedSchema = outputSchema
|
||||
true
|
||||
}
|
||||
|
||||
override def pruneColumns(requiredSchema: StructType): Unit = {
|
||||
// JDBC doesn't support nested column pruning.
|
||||
// TODO (SPARK-32593): JDBC support nested column and nested column pruning.
|
||||
|
@ -65,6 +119,20 @@ case class JDBCScanBuilder(
|
|||
val resolver = session.sessionState.conf.resolver
|
||||
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
|
||||
val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
|
||||
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter)
|
||||
|
||||
// in prunedSchema, the schema is either pruned in pushAggregation (if aggregates are
|
||||
// pushed down), or pruned in pruneColumns (in regular column pruning). These
|
||||
// two are mutual exclusive.
|
||||
// For aggregate push down case, we want to pass down the quoted column lists such as
|
||||
// "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from
|
||||
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
|
||||
// be used in sql string.
|
||||
val groupByColumns = if (pushedAggregations.nonEmpty) {
|
||||
Some(pushedAggregations.get.groupByColumns)
|
||||
} else {
|
||||
Option.empty[Array[FieldReference]]
|
||||
}
|
||||
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema, pushedFilter,
|
||||
pushedAggregateColumn, groupByColumns)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,16 +21,16 @@ import java.sql.{Connection, DriverManager}
|
|||
import java.util.Properties
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.sql.{QueryTest, Row}
|
||||
import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
|
||||
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Filter
|
||||
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
|
||||
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
|
||||
import org.apache.spark.sql.functions.lit
|
||||
import org.apache.spark.sql.functions.{lit, sum, udf}
|
||||
import org.apache.spark.sql.test.SharedSparkSession
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
||||
class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
|
||||
import testImplicits._
|
||||
|
||||
val tempDir = Utils.createTempDir()
|
||||
|
@ -41,6 +41,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
|||
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
|
||||
.set("spark.sql.catalog.h2.url", url)
|
||||
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
|
||||
.set("spark.sql.catalog.h2.pushDownAggregate", "true")
|
||||
|
||||
private def withConnection[T](f: Connection => T): T = {
|
||||
val conn = DriverManager.getConnection(url, new Properties())
|
||||
|
@ -64,6 +65,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
|||
.executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred', 1)").executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
|
||||
conn.prepareStatement(
|
||||
"CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
|
||||
" bonus DOUBLE)").executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
|
||||
.executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
|
||||
.executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
|
||||
.executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
|
||||
.executeUpdate()
|
||||
conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
|
||||
.executeUpdate()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,6 +98,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
|||
case f: Filter => f
|
||||
}
|
||||
assert(filters.isEmpty)
|
||||
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
|
||||
checkAnswer(df, Row("mary", 2))
|
||||
}
|
||||
|
||||
|
@ -145,7 +167,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
|||
|
||||
test("show tables") {
|
||||
checkAnswer(sql("SHOW TABLES IN h2.test"),
|
||||
Seq(Row("test", "people", false), Row("test", "empty_table", false)))
|
||||
Seq(Row("test", "people", false), Row("test", "empty_table", false),
|
||||
Row("test", "employee", false)))
|
||||
}
|
||||
|
||||
test("SQL API: create table as select") {
|
||||
|
@ -214,4 +237,232 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
|
|||
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4))
|
||||
}
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: MAX MIN with filter and group by") {
|
||||
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
|
||||
" group by DEPT")
|
||||
val filters = df.queryExecution.optimizedPlan.collect {
|
||||
case f: Filter => f
|
||||
}
|
||||
assert(filters.isEmpty)
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
|
||||
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: MAX MIN with filter without group by") {
|
||||
val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
|
||||
val filters = df.queryExecution.optimizedPlan.collect {
|
||||
case f: Filter => f
|
||||
}
|
||||
assert(filters.isEmpty)
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Max(ID), Min(ID)], " +
|
||||
"PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
|
||||
"PushedGroupby: []"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(2, 1)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: aggregate + number") {
|
||||
val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Max(SALARY)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(12001)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: COUNT(*)") {
|
||||
val df = sql("select COUNT(*) FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [CountStar()]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(5)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: COUNT(col)") {
|
||||
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Count(DEPT,false)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(5)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: COUNT(DISTINCT col)") {
|
||||
val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Count(DEPT,true)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(3)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: SUM without filer and group by") {
|
||||
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(53000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: DISTINCT SUM without filer and group by") {
|
||||
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(31000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: SUM with group by") {
|
||||
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
|
||||
"PushedFilters: [], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: DISTINCT SUM with group by") {
|
||||
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
|
||||
"PushedFilters: [], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: with multiple group by columns") {
|
||||
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
|
||||
" group by DEPT, NAME")
|
||||
val filters11 = df.queryExecution.optimizedPlan.collect {
|
||||
case f: Filter => f
|
||||
}
|
||||
assert(filters11.isEmpty)
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
|
||||
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
|
||||
"PushedGroupby: [DEPT, NAME]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
|
||||
Row(10000, 1000), Row(12000, 1200)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: with having clause") {
|
||||
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
|
||||
" group by DEPT having MIN(BONUS) > 1000")
|
||||
val filters = df.queryExecution.optimizedPlan.collect {
|
||||
case f: Filter => f // filter over aggregate not push down
|
||||
}
|
||||
assert(filters.nonEmpty)
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Max(SALARY), Min(BONUS)], " +
|
||||
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: alias over aggregate") {
|
||||
val df = sql("select * from h2.test.employee")
|
||||
.groupBy($"DEPT")
|
||||
.min("SALARY").as("total")
|
||||
df.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Min(SALARY)], " +
|
||||
"PushedFilters: [], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(df, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: order by alias over aggregate") {
|
||||
val df = spark.table("h2.test.employee")
|
||||
val query = df.select($"DEPT", $"SALARY")
|
||||
.filter($"DEPT" > 0)
|
||||
.groupBy($"DEPT")
|
||||
.agg(sum($"SALARY").as("total"))
|
||||
.filter($"total" > 1000)
|
||||
.orderBy($"total")
|
||||
val filters = query.queryExecution.optimizedPlan.collect {
|
||||
case f: Filter => f
|
||||
}
|
||||
assert(filters.nonEmpty) // filter over aggregate not pushed down
|
||||
query.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
|
||||
"PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
|
||||
"PushedGroupby: [DEPT]"
|
||||
checkKeywordsExistsInExplain(query, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: udf over aggregate") {
|
||||
val df = spark.table("h2.test.employee")
|
||||
val decrease = udf { (x: Double, y: Double) => x - y }
|
||||
val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
|
||||
query.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: [Sum(SALARY,DecimalType(30,2),false), Sum(BONUS,DoubleType,false)"
|
||||
checkKeywordsExistsInExplain(query, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(query, Seq(Row(47100.0)))
|
||||
}
|
||||
|
||||
test("scan with aggregate push-down: aggregate over alias") {
|
||||
val cols = Seq("a", "b", "c", "d")
|
||||
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
|
||||
val df2 = df1.groupBy().sum("c")
|
||||
df2.queryExecution.optimizedPlan.collect {
|
||||
case _: DataSourceV2ScanRelation =>
|
||||
val expected_plan_fragment =
|
||||
"PushedAggregates: []" // aggregate over alias not push down
|
||||
checkKeywordsExistsInExplain(df2, expected_plan_fragment)
|
||||
}
|
||||
checkAnswer(df2, Seq(Row(53000.00)))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue