[SPARK-10978][SQL][FOLLOW-UP] More comprehensive tests for PR #9399

This PR adds test cases covering various column pruning and filter push-down scenarios.

Author: Cheng Lian <lian@databricks.com>

Closes #9468 from liancheng/spark-10978.follow-up.
Cheng Lian 2015-11-06 11:11:36 -08:00 committed by Yin Huai
parent 574141a298
commit c048929c6a
3 changed files with 318 additions and 43 deletions
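
These tests exercise the data source scan contract extended in PR #9399: Catalyst hands a relation the pruned column list and the predicates it could convert to data source `Filter`s via `PrunedFilteredScan.buildScan`, and the relation reports back the filters it cannot fully evaluate via `BaseRelation.unhandledFilters`. A minimal sketch of that contract, using a hypothetical `SketchRelation` (not code from this commit) that handles only `GreaterThan` filters:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan, PrunedFilteredScan}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Hypothetical relation used only to illustrate the contract under test.
class SketchRelation(override val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType =
    new StructType().add("a", IntegerType).add("b", IntegerType).add("c", StringType)

  // Catalyst passes in only the columns the query needs (column pruning) and
  // the predicates it could translate into data source Filters (push-down).
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] =
    sqlContext.sparkContext.emptyRDD[Row]

  // Filters returned here are still evaluated by Spark on top of the scan;
  // everything else is assumed to be fully handled by this relation.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot(_.isInstanceOf[GreaterThan])
}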


@@ -17,16 +17,15 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.execution.datasources.LogicalRelation
import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class FilteredScanSource extends RelationProvider {
override def createRelation(
@@ -130,7 +129,7 @@ object ColumnsRequired {
var set: Set[String] = Set.empty
}
class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper {
protected override lazy val sql = caseInsensitiveContext.sql _
override def beforeAll(): Unit = {
@@ -144,9 +143,6 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
| to '10'
|)
""".stripMargin)
// UDF for testing filter push-down
caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3)
}
sqlTest(
@@ -276,14 +272,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c"))
testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c"))
// Columns referenced only by a UDF filter must be required, as UDF filters can't be pushed down.
testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c"))
// Filters referencing multiple columns are not convertible, so all referenced columns must be
// required.
testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c"))
// A query with an unconvertible filter, an unhandled filter, and a handled filter.
// A query with an inconvertible filter, an unhandled filter, and a handled filter.
testPushDown(
"""SELECT a
| FROM oneToTenFiltered
| WHERE udf_gt3(b)
| WHERE a + b > 9
| AND b < 16
| AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo')
""".stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b"))


@@ -17,15 +17,21 @@
package org.apache.spark.sql.sources
import java.io.File
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper}
import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, Row, execution}
import org.apache.spark.util.Utils
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper {
import testImplicits._
override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
@@ -70,43 +76,304 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
}
}
private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName)
private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName)
private var tempPath: File = _
test("unhandledFilters") {
withTempPath { dir =>
private var partitionedDF: DataFrame = _
val path = dir.getCanonicalPath
writer.save(s"$path/p=0")
writer.save(s"$path/p=1")
private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil)
val isOdd = udf((_: Int) % 2 == 1)
val df = reader.load(path)
.filter(
// This filter is inconvertible
isOdd('a) &&
// This filter is convertible but unhandled
'a > 1 &&
// This filter is convertible and handled
'b > "val_1" &&
// This filter references a partition column, so it won't be pushed down
'p === 1
).select('a, 'p)
val rawScan = df.queryExecution.executedPlan collect {
protected override def beforeAll(): Unit = {
this.tempPath = Utils.createTempDir()
val df = sqlContext.range(10).select(
'id cast IntegerType as 'a,
('id cast IntegerType) * 2 as 'b,
concat(lit("val_"), 'id) as 'c
)
partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=0")
partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=1")
partitionedDF = partitionedReader.load(tempPath.getCanonicalPath)
}
override protected def afterAll(): Unit = {
Utils.deleteRecursively(tempPath)
}
private def partitionedWriter(df: DataFrame) =
df.write.option("dataSchema", partitionedDataSchema.json).format(dataSourceName)
private def partitionedReader =
sqlContext.read.option("dataSchema", partitionedDataSchema.json).format(dataSourceName)
/**
* Constructs test cases that test column pruning and filter push-down.
*
* For filter push-down, the following filters are not pushed down:
*
* 1. Partitioning filters don't participate in filter push-down; they are handled separately in
* `DataSourceStrategy`.
*
* 2. Catalyst filter `Expression`s that cannot be converted to data source `Filter`s are not
* pushed down (e.g. UDFs and filters referencing multiple columns).
*
* 3. Catalyst filter `Expression`s that can be converted to data source `Filter`s but cannot be
* handled by the underlying data source are not pushed down (e.g. returned from
* `BaseRelation.unhandledFilters()`).
*
* Note that for [[SimpleTextRelation]], all data source [[Filter]]s other than [[GreaterThan]]
* are unhandled. We made this assumption in [[SimpleTextRelation.unhandledFilters()]] only
* for testing purposes.
*
* @param projections Projection list of the query
* @param filter Filter condition of the query
* @param requiredColumns Expected names of required columns
* @param pushedFilters Expected data source [[Filter]]s that are pushed down
* @param inconvertibleFilters Expected Catalyst filter [[Expression]]s that cannot be converted
* to data source [[Filter]]s
* @param unhandledFilters Expected Catalyst filter [[Expression]]s that can be converted to data
* source [[Filter]]s but cannot be handled by the data source relation
* @param partitioningFilters Expected Catalyst filter [[Expression]]s that reference partition
* columns
* @param expectedRawScanAnswer Expected query result of the raw table scan returned by the data
* source relation
* @param expectedAnswer Expected query result of the full query
*/
def testPruningAndFiltering(
projections: Seq[Column],
filter: Column,
requiredColumns: Seq[String],
pushedFilters: Seq[Filter],
inconvertibleFilters: Seq[Column],
unhandledFilters: Seq[Column],
partitioningFilters: Seq[Column])(
expectedRawScanAnswer: => Seq[Row])(
expectedAnswer: => Seq[Row]): Unit = {
test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") {
val df = partitionedDF.where(filter).select(projections: _*)
val queryExecution = df.queryExecution
val executedPlan = queryExecution.executedPlan
val rawScan = executedPlan.collect {
case p: PhysicalRDD => p
} match {
case Seq(p) => p
case Seq(scan) => scan
case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
}
val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType)
markup("Checking raw scan answer")
checkAnswer(
DataFrame(sqlContext, LogicalRDD(rawScan.output, rawScan.rdd)(sqlContext)),
expectedRawScanAnswer)
assertResult(Set((2, 1), (3, 1))) {
rawScan.execute().collect()
.map { CatalystTypeConverters.convertToScala(_, outputSchema) }
.map { case Row(a, p) => (a, p) }.toSet
markup("Checking full query answer")
checkAnswer(df, expectedAnswer)
markup("Checking required columns")
assert(requiredColumns === SimpleTextRelation.requiredColumns)
val nonPushedFilters = {
val boundFilters = executedPlan.collect {
case f: execution.Filter => f
} match {
case Nil => Nil
case Seq(f) => splitConjunctivePredicates(f.condition)
case _ => fail(s"More than one Filter node found\n$queryExecution")
}
// Unbind these bound filters so that we can easily compare them with the expected results.
boundFilters.map {
_.transform { case a: AttributeReference => UnresolvedAttribute(a.name) }
}.toSet
}
checkAnswer(df, Row(3, 1))
markup("Checking pushed filters")
assert(SimpleTextRelation.pushedFilters === pushedFilters.toSet)
val expectedInconvertibleFilters = inconvertibleFilters.map(_.expr).toSet
val expectedUnhandledFilters = unhandledFilters.map(_.expr).toSet
val expectedPartitioningFilters = partitioningFilters.map(_.expr).toSet
markup("Checking unhandled and inconvertible filters")
assert(expectedInconvertibleFilters ++ expectedUnhandledFilters === nonPushedFilters)
markup("Checking partitioning filters")
val actualPartitioningFilters = splitConjunctivePredicates(filter.expr).filter {
_.references.contains(UnresolvedAttribute("p"))
}.toSet
// Partitioning filters are handled separately and don't participate in filter push-down, so they
// shouldn't be part of the non-pushed filters.
assert(expectedPartitioningFilters.intersect(nonPushedFilters).isEmpty)
assert(expectedPartitioningFilters === actualPartitioningFilters)
}
}
testPruningAndFiltering(
projections = Seq('*),
filter = 'p > 0,
requiredColumns = Seq("a", "b", "c"),
pushedFilters = Nil,
inconvertibleFilters = Nil,
unhandledFilters = Nil,
partitioningFilters = Seq('p > 0)
) {
Seq(
Row(0, 0, "val_0", 1),
Row(1, 2, "val_1", 1),
Row(2, 4, "val_2", 1),
Row(3, 6, "val_3", 1),
Row(4, 8, "val_4", 1),
Row(5, 10, "val_5", 1),
Row(6, 12, "val_6", 1),
Row(7, 14, "val_7", 1),
Row(8, 16, "val_8", 1),
Row(9, 18, "val_9", 1))
} {
Seq(
Row(0, 0, "val_0", 1),
Row(1, 2, "val_1", 1),
Row(2, 4, "val_2", 1),
Row(3, 6, "val_3", 1),
Row(4, 8, "val_4", 1),
Row(5, 10, "val_5", 1),
Row(6, 12, "val_6", 1),
Row(7, 14, "val_7", 1),
Row(8, 16, "val_8", 1),
Row(9, 18, "val_9", 1))
}
testPruningAndFiltering(
projections = Seq('c, 'p),
filter = 'a < 3 && 'p > 0,
requiredColumns = Seq("c", "a"),
pushedFilters = Nil,
inconvertibleFilters = Nil,
unhandledFilters = Seq('a < 3),
partitioningFilters = Seq('p > 0)
) {
Seq(
Row("val_0", 1, 0),
Row("val_1", 1, 1),
Row("val_2", 1, 2),
Row("val_3", 1, 3),
Row("val_4", 1, 4),
Row("val_5", 1, 5),
Row("val_6", 1, 6),
Row("val_7", 1, 7),
Row("val_8", 1, 8),
Row("val_9", 1, 9))
} {
Seq(
Row("val_0", 1),
Row("val_1", 1),
Row("val_2", 1))
}
testPruningAndFiltering(
projections = Seq('*),
filter = 'a > 8,
requiredColumns = Seq("a", "b", "c"),
pushedFilters = Seq(GreaterThan("a", 8)),
inconvertibleFilters = Nil,
unhandledFilters = Nil,
partitioningFilters = Nil
) {
Seq(
Row(9, 18, "val_9", 0),
Row(9, 18, "val_9", 1))
} {
Seq(
Row(9, 18, "val_9", 0),
Row(9, 18, "val_9", 1))
}
testPruningAndFiltering(
projections = Seq('b, 'p),
filter = 'a > 8,
requiredColumns = Seq("b"),
pushedFilters = Seq(GreaterThan("a", 8)),
inconvertibleFilters = Nil,
unhandledFilters = Nil,
partitioningFilters = Nil
) {
Seq(
Row(18, 0),
Row(18, 1))
} {
Seq(
Row(18, 0),
Row(18, 1))
}
testPruningAndFiltering(
projections = Seq('b, 'p),
filter = 'a > 8 && 'p > 0,
requiredColumns = Seq("b"),
pushedFilters = Seq(GreaterThan("a", 8)),
inconvertibleFilters = Nil,
unhandledFilters = Nil,
partitioningFilters = Seq('p > 0)
) {
Seq(
Row(18, 1))
} {
Seq(
Row(18, 1))
}
testPruningAndFiltering(
projections = Seq('b, 'p),
filter = 'c > "val_7" && 'b < 18 && 'p > 0,
requiredColumns = Seq("b"),
pushedFilters = Seq(GreaterThan("c", "val_7")),
inconvertibleFilters = Nil,
unhandledFilters = Seq('b < 18),
partitioningFilters = Seq('p > 0)
) {
Seq(
Row(16, 1),
Row(18, 1))
} {
Seq(
Row(16, 1))
}
testPruningAndFiltering(
projections = Seq('b, 'p),
filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0,
requiredColumns = Seq("b", "a"),
pushedFilters = Seq(GreaterThan("c", "val_7")),
inconvertibleFilters = Seq('a % 2 === 0),
unhandledFilters = Seq('b < 18),
partitioningFilters = Seq('p > 0)
) {
Seq(
Row(16, 1, 8),
Row(18, 1, 9))
} {
Seq(
Row(16, 1))
}
testPruningAndFiltering(
projections = Seq('b, 'p),
filter = 'a > 7 && 'a < 9,
requiredColumns = Seq("b", "a"),
pushedFilters = Seq(GreaterThan("a", 7)),
inconvertibleFilters = Nil,
unhandledFilters = Seq('a < 9),
partitioningFilters = Nil
) {
Seq(
Row(16, 0, 8),
Row(16, 1, 8),
Row(18, 0, 9),
Row(18, 1, 9))
} {
Seq(
Row(16, 0),
Row(16, 1))
}
}
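
Both suites mix in `PredicateHelper` so a conjunctive filter condition can be split into the pushed, unhandled, and partitioning parts described in the ScalaDoc above. A small standalone sketch of that decomposition, using the `'c > "val_7" && 'b < 18 && 'p > 0` case from the suite (hypothetical object, not part of this commit):

import org.apache.spark.sql.catalyst.expressions.{Expression, PredicateHelper}
import org.apache.spark.sql.functions.col

object SplitFilterExample extends PredicateHelper {
  // Split c > 'val_7' AND b < 18 AND p > 0 into its three conjuncts so the
  // pushed (GreaterThan("c", "val_7")), unhandled (b < 18), and partitioning
  // (p > 0) parts can be checked independently.
  def conjuncts: Seq[Expression] = {
    val condition = (col("c") > "val_7" && col("b") < 18 && col("p") > 0).expr
    splitConjunctivePredicates(condition)
  }
}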


@@ -128,6 +128,9 @@ class SimpleTextRelation(
filters: Array[Filter],
inputFiles: Array[FileStatus]): RDD[Row] = {
SimpleTextRelation.requiredColumns = requiredColumns
SimpleTextRelation.pushedFilters = filters.toSet
val fields = this.dataSchema.map(_.dataType)
val inputAttributes = this.dataSchema.toAttributes
val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name))
@@ -191,6 +194,14 @@ class SimpleTextRelation(
}
}
object SimpleTextRelation {
// Used to test column pruning
var requiredColumns: Seq[String] = Nil
// Used to test filter push-down
var pushedFilters: Set[Filter] = Set.empty
}
/**
* A simple example [[HadoopFsRelationProvider]].
*/
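
`buildScan` above records the pruned column list and the pushed filters so the suites can assert on them; the pruning itself works by matching the requested names against the data schema (`requiredColumns.flatMap(name => inputAttributes.find(_.name == name))`). A rough standalone sketch of that name-based projection, with a hypothetical PruneColumns helper that is not part of this commit:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

object PruneColumns {
  // Keep only the fields named in requiredColumns, in the requested order.
  def prune(schema: StructType, requiredColumns: Array[String], row: Row): Row = {
    val indices = requiredColumns.map(schema.fieldIndex(_))
    Row.fromSeq(indices.map(row.get(_)).toSeq)
  }
}

With the test schema a int, b int, c string, prune(schema, Array("b"), Row(9, 18, "val_9")) yields Row(18), matching the requiredColumns = Seq("b") cases in the suite above.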