[SPARK-15616][SQL] Add optimizer rule PruneHiveTablePartitions

### What changes were proposed in this pull request?
Add an optimizer rule, PruneHiveTablePartitions, that prunes Hive table partitions based on filters on the partition columns.
With pruning applied, the total size of the remaining partitions may be small enough for the JoinSelection strategy to pick a broadcast join.
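As an illustration (the table and column names below are hypothetical, not part of this PR), the target case is a join whose smaller side is a partitioned Hive table filtered on a partition column:

```scala
// Hypothetical tables: `sales` is a Hive table partitioned by `dt`,
// `customers` is a large unpartitioned table.
val joined = spark.sql(
  """
    |SELECT c.name, s.amount
    |FROM sales s JOIN customers c ON s.customer_id = c.id
    |WHERE s.dt = '2020-01-01'
  """.stripMargin)
// Once the size estimate of `sales` covers only the pruned partitions,
// the physical plan can show BroadcastHashJoin instead of SortMergeJoin.
joined.explain()
```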

### Why are the changes needed?
In the JoinSelection strategy, Spark uses "plan.stats.sizeInBytes" to decide whether a plan is suitable for a broadcast join.
Currently, "plan.stats.sizeInBytes" for a Hive table does not take pruned partitions into account, so Spark may miss broadcast-join opportunities and fall back to a sort-merge join, which hurts join performance.
This PR makes "plan.stats.sizeInBytes" account for the pruned partitions of a Hive table, so that a broadcast join can be chosen where possible.
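As a sketch of how this feeds the join selection (reusing the partitioned table `test` created by the new test suite below), the size estimate that is compared with "spark.sql.autoBroadcastJoinThreshold" can be inspected on the optimized plan:

```scala
// Assumes the Hive table `test(i int) PARTITIONED BY (p int)` from the
// test suite in this PR exists and is populated.
val optimized = spark.sql("SELECT i FROM test WHERE p = 1").queryExecution.optimizedPlan
// With PruneHiveTablePartitions applied, this estimate covers only the
// partitions matching p = 1; JoinSelection compares it against
// spark.sql.autoBroadcastJoinThreshold (10 MB by default).
println(optimized.stats.sizeInBytes)
```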

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Added unit tests.

This is based on #25919; credit should go to lianhuiwang and advancedxy.

Closes #26805 from fuwhu/SPARK-15616.

Authored-by: fuwhu <bestwwg@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
fuwhu 2020-01-21 21:26:30 +08:00 committed by Wenchen Fan
parent ff39c9271c
commit cfb1706eaa
7 changed files with 193 additions and 8 deletions

View file

@@ -651,7 +651,9 @@ case class HiveTableRelation(
tableMeta: CatalogTable,
dataCols: Seq[AttributeReference],
partitionCols: Seq[AttributeReference],
-tableStats: Option[Statistics] = None) extends LeafNode with MultiInstanceRelation {
+tableStats: Option[Statistics] = None,
+@transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None)
+  extends LeafNode with MultiInstanceRelation {
assert(tableMeta.identifier.database.isDefined)
assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType))
assert(tableMeta.dataSchema.sameType(dataCols.toStructType))

View file

@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan}
import org.apache.spark.sql.types.StructType
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {

View file

@@ -84,7 +84,7 @@ object TPCDSQueryBenchmark extends SqlBasedBenchmark {
queryRelations.add(alias.identifier)
case LogicalRelation(_, _, Some(catalogTable), _) =>
queryRelations.add(catalogTable.identifier.table)
-case HiveTableRelation(tableMeta, _, _, _) =>
+case HiveTableRelation(tableMeta, _, _, _, _) =>
queryRelations.add(tableMeta.identifier.table)
case _ =>
}

View file

@@ -21,13 +21,15 @@ import org.apache.spark.annotation.Unstable
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.{Analyzer, ResolveSessionCatalog}
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck
import org.apache.spark.sql.hive.client.HiveClient
+import org.apache.spark.sql.hive.execution.PruneHiveTablePartitions
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
/**
@@ -93,6 +95,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
customCheckRules
}
+/**
+ * Logical query plan optimizer that takes into account Hive.
+ */
+override protected def optimizer: Optimizer = {
+  new SparkOptimizer(catalogManager, catalog, experimentalMethods) {
+    override def postHocOptimizationBatches: Seq[Batch] = Seq(
+      Batch("Prune Hive Table Partitions", Once, new PruneHiveTablePartitions(session))
+    )
+    override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+      super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+  }
+}
/**
* Planner that takes into account Hive-specific strategies.
*/

View file

@@ -166,14 +166,14 @@ case class HiveTableScanExec(
@transient lazy val rawPartitions = {
val prunedPartitions =
if (sparkSession.sessionState.conf.metastorePartitionPruning &&
-partitionPruningPred.size > 0) {
+partitionPruningPred.nonEmpty) {
// Retrieve the original attributes based on expression ID so that capitalization matches.
val normalizedFilters = partitionPruningPred.map(_.transform {
case a: AttributeReference => originalAttributes(a)
})
-sparkSession.sessionState.catalog.listPartitionsByFilter(
-  relation.tableMeta.identifier,
-  normalizedFilters)
+relation.prunedPartitions.getOrElse(
+  sparkSession.sessionState.catalog
+    .listPartitionsByFilter(relation.tableMeta.identifier, normalizedFilters))
} else {
sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier)
}

View file

@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.internal.SQLConf
/**
* TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
*/
private[sql] class PruneHiveTablePartitions(session: SparkSession)
extends Rule[LogicalPlan] with CastSupport {
override val conf: SQLConf = session.sessionState.conf
/**
* Extract the partition filters from the filters on the table.
*/
private def getPartitionKeyFilters(
filters: Seq[Expression],
relation: HiveTableRelation): ExpressionSet = {
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
val partitionColumnSet = AttributeSet(relation.partitionCols)
ExpressionSet(normalizedFilters.filter { f =>
!f.references.isEmpty && f.references.subsetOf(partitionColumnSet)
})
}
/**
* Prune the hive table using filters on the partitions of the table.
*/
private def prunePartitions(
relation: HiveTableRelation,
partitionFilters: ExpressionSet): Seq[CatalogTablePartition] = {
if (conf.metastorePartitionPruning) {
session.sessionState.catalog.listPartitionsByFilter(
relation.tableMeta.identifier, partitionFilters.toSeq)
} else {
ExternalCatalogUtils.prunePartitionsByFilter(relation.tableMeta,
session.sessionState.catalog.listPartitions(relation.tableMeta.identifier),
partitionFilters.toSeq, conf.sessionLocalTimeZone)
}
}
/**
* Update the statistics of the table.
*/
private def updateTableMeta(
tableMeta: CatalogTable,
prunedPartitions: Seq[CatalogTablePartition]): CatalogTable = {
val sizeOfPartitions = prunedPartitions.map { partition =>
val rawDataSize = partition.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
val totalSize = partition.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
if (rawDataSize.isDefined && rawDataSize.get > 0) {
rawDataSize.get
} else if (totalSize.isDefined && totalSize.get > 0L) {
totalSize.get
} else {
0L
}
}
if (sizeOfPartitions.forall(_ > 0)) {
val sizeInBytes = sizeOfPartitions.sum
tableMeta.copy(stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
} else {
tableMeta
}
}
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
if (partitionKeyFilters.nonEmpty) {
val newPartitions = prunePartitions(relation, partitionKeyFilters)
val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
val newRelation = relation.copy(
tableMeta = newTableMeta, prunedPartitions = Some(newPartitions))
// Keep partition filters so that they are visible in physical planning
Project(projections, Filter(filters.reduceLeft(And), newRelation))
} else {
op
}
}
}

View file

@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("PruneHiveTablePartitions", Once,
EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil
}
test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") {
withTable("test", "temp") {
sql(
s"""
|CREATE TABLE test(i int)
|PARTITIONED BY (p int)
|STORED AS textfile""".stripMargin)
spark.range(0, 1000, 1).selectExpr("id as col")
.createOrReplaceTempView("temp")
for (part <- Seq(1, 2, 3, 4)) {
sql(
s"""
|INSERT OVERWRITE TABLE test PARTITION (p='$part')
|select col from temp""".stripMargin)
}
val analyzed1 = sql("select i from test where p > 0").queryExecution.analyzed
val analyzed2 = sql("select i from test where p = 1").queryExecution.analyzed
assert(Optimize.execute(analyzed1).stats.sizeInBytes / 4 ===
Optimize.execute(analyzed2).stats.sizeInBytes)
}
}
}