[SPARK-23803][SQL] Support bucket pruning

## What changes were proposed in this pull request?
Support bucket pruning when filtering on a single bucketed column, for the following predicates:
EqualTo, EqualNullSafe, In, and And/Or combinations of them.
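
For example, with a table bucketed on `j`, an equality filter on `j` now lets the scan read only the bucket files whose id matches the filter value, instead of all buckets. A minimal spark-shell sketch (table name, column names, and bucket count are illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("bucket-pruning-sketch").getOrCreate()
import spark.implicits._

// Write a table with 8 buckets on column "j".
(0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
  .write
  .format("json")
  .bucketBy(8, "j")
  .saveAsTable("bucketed_table")

// With this change, the scan below only reads the bucket files that can
// contain rows with j = 5, rather than all 8 buckets.
spark.table("bucketed_table").filter($"j" === 5).collect()
```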

## How was this patch tested?
Refactored the existing unit tests to cover the cases above.

Based on gatorsmile's work in e3c75c6398.

Author: Asher Saban <asaban@palantir.com>
Author: asaban <asaban@palantir.com>

Closes #20915 from sabanas/filter-prune-buckets.
Authored by Asher Saban on 2018-06-06 07:14:08 -07:00; committed by Wenchen Fan
parent e9efb62e07
commit e76b0124fb
5 changed files with 231 additions and 35 deletions

File: DataSourceScanExec.scala

@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet
trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
val relation: BaseRelation
@ -151,6 +152,7 @@ case class RowDataSourceScanExec(
* @param output Output attributes of the scan, including data attributes and partition attributes.
* @param requiredSchema Required schema of the underlying relation, excluding partition columns.
* @param partitionFilters Predicates to use for partition pruning.
* @param optionalBucketSet Bucket ids for bucket pruning
* @param dataFilters Filters on non-partition columns.
* @param tableIdentifier identifier for the table in the metastore.
*/
@ -159,6 +161,7 @@ case class FileSourceScanExec(
output: Seq[Attribute],
requiredSchema: StructType,
partitionFilters: Seq[Expression],
optionalBucketSet: Option[BitSet],
dataFilters: Seq[Expression],
override val tableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with ColumnarBatchScan {
@ -286,7 +289,20 @@ case class FileSourceScanExec(
} getOrElse {
metadata
}
withOptPartitionCount
val withSelectedBucketsCount = relation.bucketSpec.map { spec =>
val numSelectedBuckets = optionalBucketSet.map { b =>
b.cardinality()
} getOrElse {
spec.numBuckets
}
withOptPartitionCount + ("SelectedBucketsCount" ->
s"$numSelectedBuckets out of ${spec.numBuckets}")
} getOrElse {
withOptPartitionCount
}
withSelectedBucketsCount
}
private lazy val inputRDD: RDD[InternalRow] = {
@ -365,7 +381,7 @@ case class FileSourceScanExec(
selectedPartitions: Seq[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
val bucketed =
val filesGroupedToBuckets =
selectedPartitions.flatMap { p =>
p.files.map { f =>
val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
@ -377,8 +393,17 @@ case class FileSourceScanExec(
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}
val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
filesGroupedToBuckets.filter {
f => bucketSet.get(f._1)
}
} else {
filesGroupedToBuckets
}
val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
}
new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
@ -503,6 +528,7 @@ case class FileSourceScanExec(
output.map(QueryPlan.normalizeExprId(_, output)),
requiredSchema,
QueryPlan.normalizePredicates(partitionFilters, output),
optionalBucketSet,
QueryPlan.normalizePredicates(dataFilters, output),
None)
}
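
The new `SelectedBucketsCount` metadata entry can be checked from the physical plan. A sketch, assuming the bucketed table from the example above (the exact counts depend on the data and the filter):

```scala
import org.apache.spark.sql.execution.FileSourceScanExec

// `spark` and spark.implicits._ as in the earlier sketch.
val plan = spark.table("bucketed_table")
  .filter($"j" === 5)
  .queryExecution.executedPlan

val scan = plan.collectFirst { case s: FileSourceScanExec => s }.get

// With pruning applied, the scan metadata carries the entry added above,
// e.g. "SelectedBucketsCount" -> "1 out of 8".
println(scan.metadata("SelectedBucketsCount"))
```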

File: BucketingUtils.scala

@ -17,6 +17,9 @@
package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
object BucketingUtils {
// The file name of bucketed data should have 3 parts:
// 1. some other information in the head of file name
@ -35,5 +38,16 @@ object BucketingUtils {
case other => None
}
// Given bucketColumn, numBuckets and value, returns the corresponding bucketId
def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
mutableInternalRow.update(0, value)
val bucketIdGenerator = UnsafeProjection.create(
HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
bucketIdGenerator(mutableInternalRow).getInt(0)
}
def bucketIdToString(id: Int): String = f"_$id%05d"
}
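
The helper moved here from `DataSourceStrategy` can be exercised on its own; it computes the bucket id the same way the write path assigns rows to bucket files. A sketch (the attribute, bucket count, and value are illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.types.IntegerType

// An attribute standing in for a bucket column "j" of type Int.
val j = AttributeReference("j", IntegerType)()

// Bucket id of the literal value 5 for a table with 8 buckets, derived from
// HashPartitioning.partitionIdExpression as in getBucketIdFromValue above.
val bucketId = BucketingUtils.getBucketIdFromValue(j, numBuckets = 8, value = 5)

// Bucket file names carry the id as a zero-padded suffix, e.g. "_00003".
println(BucketingUtils.bucketIdToString(bucketId))
```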

File: DataSourceStrategy.scala

@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
case _ => Nil
}
// Get the bucket ID based on the bucketing values.
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
bucketIdGeneration(mutableRow).getInt(0)
}
// Based on Public API.
private def pruneFilterProject(
relation: LogicalRelation,

File: FileSourceStrategy.scala

@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.util.collection.BitSet
/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@ -50,6 +51,91 @@ import org.apache.spark.sql.execution.SparkPlan
* and add it. Proceed to the next file.
*/
object FileSourceStrategy extends Strategy with Logging {
// should prune buckets iff num buckets is greater than 1 and there is only one bucket column
private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
bucketSpec match {
case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1
case None => false
}
}
private def getExpressionBuckets(
expr: Expression,
bucketColumnName: String,
numBuckets: Int): BitSet = {
def getBucketNumber(attr: Attribute, v: Any): Int = {
BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)
}
def getBucketSetFromIterable(attr: Attribute, iter: Iterable[Any]): BitSet = {
val matchedBuckets = new BitSet(numBuckets)
iter
.map(v => getBucketNumber(attr, v))
.foreach(bucketNum => matchedBuckets.set(bucketNum))
matchedBuckets
}
def getBucketSetFromValue(attr: Attribute, v: Any): BitSet = {
val matchedBuckets = new BitSet(numBuckets)
matchedBuckets.set(getBucketNumber(attr, v))
matchedBuckets
}
expr match {
case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
getBucketSetFromValue(a, v)
case expressions.In(a: Attribute, list)
if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
getBucketSetFromIterable(a, list.map(e => e.eval(EmptyRow)))
case expressions.InSet(a: Attribute, hset)
if hset.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
getBucketSetFromIterable(a, hset.map(e => expressions.Literal(e).eval(EmptyRow)))
case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
getBucketSetFromValue(a, null)
case expressions.And(left, right) =>
getExpressionBuckets(left, bucketColumnName, numBuckets) &
getExpressionBuckets(right, bucketColumnName, numBuckets)
case expressions.Or(left, right) =>
getExpressionBuckets(left, bucketColumnName, numBuckets) |
getExpressionBuckets(right, bucketColumnName, numBuckets)
case _ =>
val matchedBuckets = new BitSet(numBuckets)
matchedBuckets.setUntil(numBuckets)
matchedBuckets
}
}
private def genBucketSet(
normalizedFilters: Seq[Expression],
bucketSpec: BucketSpec): Option[BitSet] = {
if (normalizedFilters.isEmpty) {
return None
}
val bucketColumnName = bucketSpec.bucketColumnNames.head
val numBuckets = bucketSpec.numBuckets
val normalizedFiltersAndExpr = normalizedFilters
.reduce(expressions.And)
val matchedBuckets = getExpressionBuckets(normalizedFiltersAndExpr, bucketColumnName,
numBuckets)
val numBucketsSelected = matchedBuckets.cardinality()
logInfo {
s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets."
}
// None means all the buckets need to be scanned
if (numBucketsSelected == numBuckets) {
None
} else {
Some(matchedBuckets)
}
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
@ -82,6 +168,13 @@ object FileSourceStrategy extends Strategy with Logging {
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec
val bucketSet = if (shouldPruneBuckets(bucketSpec)) {
genBucketSet(normalizedFilters, bucketSpec.get)
} else {
None
}
val dataColumns =
l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
@ -111,6 +204,7 @@ object FileSourceStrategy extends Strategy with Logging {
outputAttributes,
outputSchema,
partitionKeyFilters.toSeq,
bucketSet,
dataFilters,
table.map(_.identifier))
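
The `And`/`Or` cases in `getExpressionBuckets` reduce to intersecting and unioning candidate bucket sets, with any sub-expression that says nothing about the bucket column contributing a full set. An illustrative snippet using the same `org.apache.spark.util.collection.BitSet` operations the rule (and the test suite) relies on; the concrete bucket ids here are placeholders that would really come from `BucketingUtils.getBucketIdFromValue`:

```scala
import org.apache.spark.util.collection.BitSet

val numBuckets = 8

// Candidate buckets for j = 5 and j = 7 (ids 3 and 6 are made-up placeholders).
val forJ5 = new BitSet(numBuckets); forJ5.set(3)
val forJ7 = new BitSet(numBuckets); forJ7.set(6)

// j = 5 OR j = 7  -> union of the candidate sets (two buckets survive).
val orBuckets = forJ5 | forJ7

// j = 5 AND i > 0 -> "i > 0" contributes a full set via setUntil(numBuckets),
// so the intersection keeps only the bucket for j = 5.
val unknown = new BitSet(numBuckets); unknown.setUntil(numBuckets)
val andBuckets = forJ5 & unknown

println(s"OR selects ${orBuckets.cardinality()} of $numBuckets buckets; " +
  s"AND selects ${andBuckets.cardinality()}.")
```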

File: BucketedReadSuite.scala

@ -22,10 +22,11 @@ import java.net.URI
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
@ -52,6 +53,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
s <- Seq(null, "a", "b", "c", "d", "e", "f", null, "g")
} yield (i % 5, s, i % 13)).toDF("i", "j", "k")
// number of buckets that doesn't yield empty buckets when bucketing on column j on df/nullDF
// empty buckets before filtering might hide bugs in pruning logic
private val NumBucketsForPruningDF = 7
private val NumBucketsForPruningNullDf = 5
test("read bucketed data") {
withTable("bucketed_table") {
df.write
@ -90,32 +96,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
originalDataFrame: DataFrame): Unit = {
// This test verifies parts of the plan. Disable whole stage codegen.
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val strategy = DataSourceStrategy(spark.sessionState.conf)
val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limit: bucket pruning only works when the bucket column has one and only one column
assert(bucketColumnNames.length == 1)
val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value))
}
// Filter could hide the bug in bucket pruning. Thus, skipping all the filters
val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
val rdd = plan.find(_.isInstanceOf[DataSourceScanExec])
assert(rdd.isDefined, plan)
val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator()
// if nothing should be pruned, skip the pruning test
if (bucketValues.nonEmpty) {
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value))
}
val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
// return indexes of partitions that should have been pruned and are not empty
if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) {
Iterator(index)
} else {
Iterator()
}
}.collect()
if (invalidBuckets.nonEmpty) {
fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan")
}
}
// TODO: These tests are not testing the right columns.
// // checking if all the pruned buckets are empty
// val invalidBuckets = checkedResult.collect().toList
// if (invalidBuckets.nonEmpty) {
// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan")
// }
checkAnswer(
bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
@ -125,7 +136,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
test("read partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val numBuckets = NumBucketsForPruningDF
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
@ -155,13 +166,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
bucketValues = Seq(j, j + 1, j + 2, j + 3),
filterCondition = $"j".isin(j, j + 1, j + 2, j + 3),
df)
// Case 4: InSet
val inSetExpr = expressions.InSet($"j".expr, Set(j, j + 1, j + 2, j + 3).map(lit(_).expr))
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j, j + 1, j + 2, j + 3),
filterCondition = Column(inSetExpr),
df)
}
}
}
test("read non-partitioning bucketed tables with bucket pruning filters") {
withTable("bucketed_table") {
val numBuckets = 8
val numBuckets = NumBucketsForPruningDF
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
@ -181,7 +200,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
test("read partitioning bucketed tables having null in bucketing key") {
withTable("bucketed_table") {
val numBuckets = 8
val numBuckets = NumBucketsForPruningNullDf
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
nullDF.write
@ -208,7 +227,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
test("read partitioning bucketed tables having composite filters") {
withTable("bucketed_table") {
val numBuckets = 8
val numBuckets = NumBucketsForPruningDF
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
@ -229,10 +248,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
bucketValues = j :: Nil,
filterCondition = $"j" === j && $"i" > j % 5,
df)
// check multiple bucket values OR condition
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j, j + 1),
filterCondition = $"j" === j || $"j" === (j + 1),
df)
// check OR condition on a bucket value and a non-bucket value
checkPrunedAnswers(
bucketSpec,
bucketValues = Nil,
filterCondition = $"j" === j || $"i" === 0,
df)
// check AND condition in complex expression
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j),
filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j,
df)
}
}
}
test("read bucketed table without filters") {
withTable("bucketed_table") {
val numBuckets = NumBucketsForPruningDF
val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
// json does not support predicate push-down, and thus json is used here
df.write
.format("json")
.bucketBy(numBuckets, "j")
.saveAsTable("bucketed_table")
val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
val plan = bucketedDataFrame.queryExecution.executedPlan
val rdd = plan.find(_.isInstanceOf[DataSourceScanExec])
assert(rdd.isDefined, plan)
val emptyBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
// return indexes of empty partitions
if (iter.isEmpty) {
Iterator(index)
} else {
Iterator()
}
}.collect()
if (emptyBuckets.nonEmpty) {
fail(s"Buckets ${emptyBuckets.mkString(",")} should not have been pruned from:\n$plan")
}
checkAnswer(
bucketedDataFrame.orderBy("i", "j", "k"),
df.orderBy("i", "j", "k"))
}
}
private lazy val df1 =
(0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
private lazy val df2 =