[SPARK-4453][SPARK-4213][SQL] Simplifies Parquet filter generation code

While reviewing PRs #3083 and #3161, I noticed that the Parquet record filter generation code can be simplified significantly, following the clue stated in [SPARK-4453](https://issues.apache.org/jira/browse/SPARK-4453). This PR addresses both SPARK-4453 and SPARK-4213 with that simplification.

While generating the `ParquetTableScan` operator, we need to remove all Catalyst predicates that have already been pushed down to Parquet. Originally, we first generated the record filter and then called `findExpression` to traverse the generated filter and find all pushed-down predicates [[1](64c6b9bad5/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala (L213-L228))]. This forced us to introduce the `CatalystFilter` class hierarchy to bind each Catalyst predicate to its generated Parquet filter, which complicates the code base a lot.

The basic idea of this PR is that we don't need `findExpression` after filter generation, because we already know a predicate can be pushed down if we successfully generate its corresponding Parquet filter. SPARK-4213 is fixed by returning `None` for any unsupported predicate type.
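
To make the new scheme concrete, here is a minimal, self-contained sketch of the idea. All names (`FilterPushDownSketch`, `ToyFilter`, `convert`, `prune`) are illustrative stand-ins; the real implementations are `ParquetFilters.createFilter` and the `prunePushedDownFilters` closure shown in the diff below.

```scala
// Minimal sketch of the simplified scheme (illustrative names only; the real
// code is ParquetFilters.createFilter plus the prunePushedDownFilters closure).
object FilterPushDownSketch {
  // Stand-ins for Catalyst predicates and Parquet FilterPredicates.
  sealed trait Expr
  case class Cmp(column: String, value: Int) extends Expr   // simple "Attribute cmp Literal"
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr
  case class Complex(desc: String) extends Expr              // unsupported, e.g. contains a CAST

  case class ToyFilter(repr: String) { override def toString: String = repr }

  // Succeeds only for predicates Parquet can evaluate, None otherwise. The
  // And/Or handling mirrors the patch: a conjunction pushes whichever half
  // converts, while a disjunction needs both halves.
  def convert(e: Expr): Option[ToyFilter] = e match {
    case Cmp(c, v)  => Some(ToyFilter(s"$c = $v"))
    case And(l, r)  =>
      (convert(l) ++ convert(r)).reduceOption((a, b) => ToyFilter(s"($a) and ($b)"))
    case Or(l, r)   =>
      for (lf <- convert(l); rf <- convert(r)) yield ToyFilter(s"($lf) or ($rf)")
    case Complex(_) => None
  }

  // "Conversion succeeded" already means "pushed down", so pruning needs no
  // findExpression-style traversal: keep exactly the predicates that yield None.
  def prune(predicates: Seq[Expr]): Seq[Expr] =
    predicates.map(p => p -> convert(p)).collect { case (p, None) => p }

  def main(args: Array[String]): Unit = {
    // "(A AND B) OR C" with an unsupported B results in "A OR C" being pushed.
    println(convert(Or(And(Cmp("a", 1), Complex("cast(b)")), Cmp("c", 3))))
    // Some((a = 1) or (c = 3))

    // Predicates that cannot be converted at all stay in the higher-level filter.
    println(prune(Seq(Cmp("a", 1), Complex("a > b"))))
    // List(Complex(a > b))
  }
}
```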


Author: Cheng Lian <lian@databricks.com>

Closes #3317 from liancheng/simplify-parquet-filters and squashes the following commits:

d6a9499 [Cheng Lian] Fixes import styling issue
43760e8 [Cheng Lian] Simplifies Parquet filter generation logic
Cheng Lian 2014-11-17 16:55:12 -08:00 committed by Michael Armbrust
parent ef7c464eff
commit 36b0956a3e
5 changed files with 180 additions and 712 deletions


@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.util.Metadata
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
def newExprId = ExprId(curId.getAndIncrement())
def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType)
}
/**

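The single-line `unapply` added above is what lets the rewritten `ParquetFilters` (further down in this diff) destructure any named expression into its name and data type directly inside a pattern match. A rough sketch of the mechanism, using simplified stand-in types rather than Catalyst's real classes:

```scala
// Sketch of how an unapply on the base trait simplifies call sites
// (simplified stand-ins for Catalyst's NamedExpression / AttributeReference).
object UnapplySketch {
  trait NamedExpr { def name: String; def dataType: String }
  object NamedExpr {
    // With this extractor, callers can destructure any NamedExpr subclass
    // without naming the concrete class.
    def unapply(e: NamedExpr): Option[(String, String)] = Some((e.name, e.dataType))
  }
  case class AttrRef(name: String, dataType: String) extends NamedExpr
  case class EqualTo(left: Any, right: Any)

  def describe(p: EqualTo): String = p match {
    // Mirrors the shape used in the new ParquetFilters:
    //   case EqualTo(NamedExpression(name, _), Literal(value, dataType)) => ...
    case EqualTo(NamedExpr(name, dt), value) => s"push: $name ($dt) = $value"
    case _                                   => "not pushable"
  }

  def main(args: Array[String]): Unit =
    println(describe(EqualTo(AttrRef("a", "int"), 1)))  // push: a (int) = 1
}
```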

@ -209,22 +209,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sqlContext.parquetFilterPushDown) {
(filters: Seq[Expression]) => {
filters.filter { filter =>
// Note: filters cannot be pushed down to Parquet if they contain more complex
// expressions than simple "Attribute cmp Literal" comparisons. Here we remove
// all filters that have been pushed down. Note that a predicate such as
// "(A AND B) OR C" can result in "A OR C" being pushed down.
val recordFilter = ParquetFilters.createFilter(filter)
if (!recordFilter.isDefined) {
// First case: the pushdown did not result in any record filter.
true
} else {
// Second case: a record filter was created; here we are conservative in
// the sense that even if "A" was pushed and we check for "A AND B" we
// still want to keep "A AND B" in the higher-level filter, not just "B".
!ParquetFilters.findExpression(recordFilter.get, filter).isDefined
}
(predicates: Seq[Expression]) => {
// Note: filters cannot be pushed down to Parquet if they contain more complex
// expressions than simple "Attribute cmp Literal" comparisons. Here we remove all
// filters that have been pushed down. Note that a predicate such as "(A AND B) OR C"
// can result in "A OR C" being pushed down. Here we are conservative in the sense
// that even if "A" was pushed and we check for "A AND B" we still want to keep
// "A AND B" in the higher-level filter, not just "B".
predicates.map(p => p -> ParquetFilters.createFilter(p)).collect {
case (predicate, None) => predicate
}
}
} else {


@ -18,406 +18,152 @@
package org.apache.spark.sql.parquet
import java.nio.ByteBuffer
import java.sql.{Date, Timestamp}
import org.apache.hadoop.conf.Configuration
import parquet.common.schema.ColumnPath
import parquet.filter2.compat.FilterCompat
import parquet.filter2.compat.FilterCompat._
import parquet.filter2.predicate.Operators.{Column, SupportsLtGt}
import parquet.filter2.predicate.{FilterApi, FilterPredicate}
import parquet.filter2.predicate.FilterApi._
import parquet.io.api.Binary
import parquet.column.ColumnReader
import com.google.common.io.BaseEncoding
import org.apache.hadoop.conf.Configuration
import parquet.filter2.compat.FilterCompat
import parquet.filter2.compat.FilterCompat._
import parquet.filter2.predicate.FilterApi._
import parquet.filter2.predicate.{FilterApi, FilterPredicate}
import parquet.io.api.Binary
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.parquet.ParquetColumns._
import org.apache.spark.sql.catalyst.types._
private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
def createRecordFilter(filterExpressions: Seq[Expression]): Filter = {
val filters: Seq[CatalystFilter] = filterExpressions.collect {
case (expression: Expression) if createFilter(expression).isDefined =>
createFilter(expression).get
}
if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null
def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = {
filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get)
}
def createFilter(expression: Expression): Option[CatalystFilter] = {
def createEqualityFilter(
name: String,
literal: Literal,
predicate: CatalystPredicate) = literal.dataType match {
def createFilter(predicate: Expression): Option[FilterPredicate] = {
val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case BooleanType =>
ComparisonFilter.createBooleanEqualityFilter(
name,
literal.value.asInstanceOf[Boolean],
predicate)
case ByteType =>
new ComparisonFilter(
name,
FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
predicate)
case ShortType =>
new ComparisonFilter(
name,
FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
predicate)
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
case IntegerType =>
new ComparisonFilter(
name,
FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]),
predicate)
(n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer])
case LongType =>
new ComparisonFilter(
name,
FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
predicate)
case DoubleType =>
new ComparisonFilter(
name,
FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
predicate)
(n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
new ComparisonFilter(
name,
FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
predicate)
(n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
(n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
ComparisonFilter.createStringEqualityFilter(
name,
literal.value.asInstanceOf[String],
predicate)
(n: String, v: Any) =>
FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
ComparisonFilter.createBinaryEqualityFilter(
name,
literal.value.asInstanceOf[Array[Byte]],
predicate)
case DateType =>
new ComparisonFilter(
name,
FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
predicate)
case TimestampType =>
new ComparisonFilter(
name,
FilterApi.eq(timestampColumn(name),
new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
predicate)
case DecimalType.Unlimited =>
new ComparisonFilter(
name,
FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
predicate)
(n: String, v: Any) =>
FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
def createLessThanFilter(
name: String,
literal: Literal,
predicate: CatalystPredicate) = literal.dataType match {
case ByteType =>
new ComparisonFilter(
name,
FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
predicate)
case ShortType =>
new ComparisonFilter(
name,
FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
predicate)
val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
new ComparisonFilter(
name,
FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]),
predicate)
(n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer])
case LongType =>
new ComparisonFilter(
name,
FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
predicate)
case DoubleType =>
new ComparisonFilter(
name,
FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
predicate)
(n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
new ComparisonFilter(
name,
FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
predicate)
case StringType =>
ComparisonFilter.createStringLessThanFilter(
name,
literal.value.asInstanceOf[String],
predicate)
case BinaryType =>
ComparisonFilter.createBinaryLessThanFilter(
name,
literal.value.asInstanceOf[Array[Byte]],
predicate)
case DateType =>
new ComparisonFilter(
name,
FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
predicate)
case TimestampType =>
new ComparisonFilter(
name,
FilterApi.lt(timestampColumn(name),
new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
predicate)
case DecimalType.Unlimited =>
new ComparisonFilter(
name,
FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
predicate)
}
def createLessThanOrEqualFilter(
name: String,
literal: Literal,
predicate: CatalystPredicate) = literal.dataType match {
case ByteType =>
new ComparisonFilter(
name,
FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
predicate)
case ShortType =>
new ComparisonFilter(
name,
FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
predicate)
case IntegerType =>
new ComparisonFilter(
name,
FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]),
predicate)
case LongType =>
new ComparisonFilter(
name,
FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
predicate)
(n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
new ComparisonFilter(
name,
FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
predicate)
case FloatType =>
new ComparisonFilter(
name,
FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
predicate)
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
ComparisonFilter.createStringLessThanOrEqualFilter(
name,
literal.value.asInstanceOf[String],
predicate)
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
ComparisonFilter.createBinaryLessThanOrEqualFilter(
name,
literal.value.asInstanceOf[Array[Byte]],
predicate)
case DateType =>
new ComparisonFilter(
name,
FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
predicate)
case TimestampType =>
new ComparisonFilter(
name,
FilterApi.ltEq(timestampColumn(name),
new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
predicate)
case DecimalType.Unlimited =>
new ComparisonFilter(
name,
FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
predicate)
}
// TODO: combine these two types somehow?
def createGreaterThanFilter(
name: String,
literal: Literal,
predicate: CatalystPredicate) = literal.dataType match {
case ByteType =>
new ComparisonFilter(
name,
FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
predicate)
case ShortType =>
new ComparisonFilter(
name,
FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
predicate)
case IntegerType =>
new ComparisonFilter(
name,
FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]),
predicate)
case LongType =>
new ComparisonFilter(
name,
FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
predicate)
case DoubleType =>
new ComparisonFilter(
name,
FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
predicate)
case FloatType =>
new ComparisonFilter(
name,
FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
predicate)
case StringType =>
ComparisonFilter.createStringGreaterThanFilter(
name,
literal.value.asInstanceOf[String],
predicate)
case BinaryType =>
ComparisonFilter.createBinaryGreaterThanFilter(
name,
literal.value.asInstanceOf[Array[Byte]],
predicate)
case DateType =>
new ComparisonFilter(
name,
FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
predicate)
case TimestampType =>
new ComparisonFilter(
name,
FilterApi.gt(timestampColumn(name),
new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
predicate)
case DecimalType.Unlimited =>
new ComparisonFilter(
name,
FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
predicate)
}
def createGreaterThanOrEqualFilter(
name: String,
literal: Literal,
predicate: CatalystPredicate) = literal.dataType match {
case ByteType =>
new ComparisonFilter(
name,
FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
predicate)
case ShortType =>
new ComparisonFilter(
name,
FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
predicate)
case IntegerType =>
new ComparisonFilter(
name,
FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]),
predicate)
case LongType =>
new ComparisonFilter(
name,
FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
predicate)
case DoubleType =>
new ComparisonFilter(
name,
FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
predicate)
case FloatType =>
new ComparisonFilter(
name,
FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
predicate)
case StringType =>
ComparisonFilter.createStringGreaterThanOrEqualFilter(
name,
literal.value.asInstanceOf[String],
predicate)
case BinaryType =>
ComparisonFilter.createBinaryGreaterThanOrEqualFilter(
name,
literal.value.asInstanceOf[Array[Byte]],
predicate)
case DateType =>
new ComparisonFilter(
name,
FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
predicate)
case TimestampType =>
new ComparisonFilter(
name,
FilterApi.gtEq(timestampColumn(name),
new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
predicate)
case DecimalType.Unlimited =>
new ComparisonFilter(
name,
FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
predicate)
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
/**
* TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
* https://github.com/Parquet/parquet-mr/issues/371
* has been resolved.
*/
expression match {
case p @ Or(left: Expression, right: Expression)
if createFilter(left).isDefined && createFilter(right).isDefined => {
// If either side of this Or-predicate is empty then this means
// it contains a more complex comparison than between attribute and literal
// (e.g., it contained a CAST). The only safe thing to do is then to disregard
// this disjunction, which could be contained in a conjunction. If it stands
// alone then it is also safe to drop it, since a Null return value of this
// function is interpreted as having no filters at all.
val leftFilter = createFilter(left).get
val rightFilter = createFilter(right).get
Some(new OrFilter(leftFilter, rightFilter))
}
case p @ And(left: Expression, right: Expression) => {
// This treats nested conjunctions; since either side of the conjunction
// may contain more complex filter expressions we may actually generate
// strictly weaker filter predicates in the process.
val leftFilter = createFilter(left)
val rightFilter = createFilter(right)
(leftFilter, rightFilter) match {
case (None, Some(filter)) => Some(filter)
case (Some(filter), None) => Some(filter)
case (Some(leftF), Some(rightF)) =>
Some(new AndFilter(leftF, rightF))
case _ => None
}
}
case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType =>
Some(createEqualityFilter(right.name, left, p))
case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType =>
Some(createEqualityFilter(left.name, right, p))
case p @ LessThan(left: Literal, right: NamedExpression) =>
Some(createLessThanFilter(right.name, left, p))
case p @ LessThan(left: NamedExpression, right: Literal) =>
Some(createLessThanFilter(left.name, right, p))
case p @ LessThanOrEqual(left: Literal, right: NamedExpression) =>
Some(createLessThanOrEqualFilter(right.name, left, p))
case p @ LessThanOrEqual(left: NamedExpression, right: Literal) =>
Some(createLessThanOrEqualFilter(left.name, right, p))
case p @ GreaterThan(left: Literal, right: NamedExpression) =>
Some(createGreaterThanFilter(right.name, left, p))
case p @ GreaterThan(left: NamedExpression, right: Literal) =>
Some(createGreaterThanFilter(left.name, right, p))
case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) =>
Some(createGreaterThanOrEqualFilter(right.name, left, p))
case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) =>
Some(createGreaterThanOrEqualFilter(left.name, right, p))
val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
(n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
(n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
(n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
(n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
(n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
(n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
(n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
(n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
(n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
case DoubleType =>
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
predicate match {
case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
makeEq.lift(dataType).map(_(name, value))
case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
makeEq.lift(dataType).map(_(name, value))
case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
makeLt.lift(dataType).map(_(name, value))
case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
makeLt.lift(dataType).map(_(name, value))
case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
makeLtEq.lift(dataType).map(_(name, value))
case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
makeLtEq.lift(dataType).map(_(name, value))
case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
makeGt.lift(dataType).map(_(name, value))
case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
makeGt.lift(dataType).map(_(name, value))
case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
makeGtEq.lift(dataType).map(_(name, value))
case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
makeGtEq.lift(dataType).map(_(name, value))
case And(lhs, rhs) =>
(createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and)
case Or(lhs, rhs) =>
for {
lhsFilter <- createFilter(lhs)
rhsFilter <- createFilter(rhs)
} yield FilterApi.or(lhsFilter, rhsFilter)
case Not(pred) =>
createFilter(pred).map(FilterApi.not)
case _ => None
}
}
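As a usage reference, the rewritten `createFilter` can be exercised the same way the updated test suite below does. The following sketch is illustrative only; note that `ParquetFilters` is `private[sql]`, so such code has to live under the `org.apache.spark.sql` package, as the test suite does.

```scala
// Illustrative usage sketch against the signatures visible in this diff.
package org.apache.spark.sql.parquet

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

object CreateFilterUsage {
  def main(args: Array[String]): Unit = {
    val a = AttributeReference("a", IntegerType, nullable = false)()
    val b = AttributeReference("b", IntegerType, nullable = false)()

    // Simple "Attribute cmp Literal" predicates convert to Parquet FilterPredicates.
    println(ParquetFilters.createFilter(EqualTo(a, Literal(1, IntegerType))).isDefined)   // true
    println(ParquetFilters.createFilter(LessThan(a, Literal(4, IntegerType))).isDefined)  // true

    // Anything more complex, such as comparing two attributes, yields None, so
    // unsupported predicates simply stay out of the Parquet filter (the SPARK-4213 fix).
    println(ParquetFilters.createFilter(GreaterThan(a, b)).isDefined)                     // false
  }
}
```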
@ -428,7 +174,7 @@ private[sql] object ParquetFilters {
* the actual filter predicate.
*/
def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = {
if (filters.length > 0) {
if (filters.nonEmpty) {
val serialized: Array[Byte] =
SparkEnv.get.closureSerializer.newInstance().serialize(filters).array()
val encoded: String = BaseEncoding.base64().encode(serialized)
@ -450,245 +196,4 @@ private[sql] object ParquetFilters {
Seq()
}
}
/**
* Try to find the given expression in the tree of filters in order to
* determine whether it is safe to remove it from the higher level filters. Note
* that strictly speaking we could stop the search whenever an expression is found
* that contains this expression as subexpression (e.g., when searching for "a"
* and "(a or c)" is found) but we don't care about optimizations here since the
* filter tree is assumed to be small.
*
* @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand
* and search
* @param expression The expression to look for
* @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that
* contains the expression.
*/
def findExpression(
filter: CatalystFilter,
expression: Expression): Option[CatalystFilter] = filter match {
case f @ OrFilter(_, leftFilter, rightFilter, _) =>
if (f.predicate == expression) {
Some(f)
} else {
val left = findExpression(leftFilter, expression)
if (left.isDefined) left else findExpression(rightFilter, expression)
}
case f @ AndFilter(_, leftFilter, rightFilter, _) =>
if (f.predicate == expression) {
Some(f)
} else {
val left = findExpression(leftFilter, expression)
if (left.isDefined) left else findExpression(rightFilter, expression)
}
case f @ ComparisonFilter(_, _, predicate) =>
if (predicate == expression) Some(f) else None
case _ => None
}
}
abstract private[parquet] class CatalystFilter(
@transient val predicate: CatalystPredicate) extends FilterPredicate
private[parquet] case class ComparisonFilter(
val columnName: String,
private var filter: FilterPredicate,
@transient override val predicate: CatalystPredicate)
extends CatalystFilter(predicate) {
override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
filter.accept(visitor)
}
}
private[parquet] case class OrFilter(
private var filter: FilterPredicate,
@transient val left: CatalystFilter,
@transient val right: CatalystFilter,
@transient override val predicate: Or)
extends CatalystFilter(predicate) {
def this(l: CatalystFilter, r: CatalystFilter) =
this(
FilterApi.or(l, r),
l,
r,
Or(l.predicate, r.predicate))
override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
filter.accept(visitor);
}
}
private[parquet] case class AndFilter(
private var filter: FilterPredicate,
@transient val left: CatalystFilter,
@transient val right: CatalystFilter,
@transient override val predicate: And)
extends CatalystFilter(predicate) {
def this(l: CatalystFilter, r: CatalystFilter) =
this(
FilterApi.and(l, r),
l,
r,
And(l.predicate, r.predicate))
override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
filter.accept(visitor);
}
}
private[parquet] object ComparisonFilter {
def createBooleanEqualityFilter(
columnName: String,
value: Boolean,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]),
predicate)
def createStringEqualityFilter(
columnName: String,
value: String,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)),
predicate)
def createStringLessThanFilter(
columnName: String,
value: String,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)),
predicate)
def createStringLessThanOrEqualFilter(
columnName: String,
value: String,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)),
predicate)
def createStringGreaterThanFilter(
columnName: String,
value: String,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)),
predicate)
def createStringGreaterThanOrEqualFilter(
columnName: String,
value: String,
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)),
predicate)
def createBinaryEqualityFilter(
columnName: String,
value: Array[Byte],
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)),
predicate)
def createBinaryLessThanFilter(
columnName: String,
value: Array[Byte],
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)),
predicate)
def createBinaryLessThanOrEqualFilter(
columnName: String,
value: Array[Byte],
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)),
predicate)
def createBinaryGreaterThanFilter(
columnName: String,
value: Array[Byte],
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)),
predicate)
def createBinaryGreaterThanOrEqualFilter(
columnName: String,
value: Array[Byte],
predicate: CatalystPredicate): CatalystFilter =
new ComparisonFilter(
columnName,
FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)),
predicate)
}
private[spark] object ParquetColumns {
def byteColumn(columnPath: String): ByteColumn = {
new ByteColumn(ColumnPath.fromDotString(columnPath))
}
final class ByteColumn(columnPath: ColumnPath)
extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt
def shortColumn(columnPath: String): ShortColumn = {
new ShortColumn(ColumnPath.fromDotString(columnPath))
}
final class ShortColumn(columnPath: ColumnPath)
extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt
def dateColumn(columnPath: String): DateColumn = {
new DateColumn(ColumnPath.fromDotString(columnPath))
}
final class DateColumn(columnPath: ColumnPath)
extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt
def timestampColumn(columnPath: String): TimestampColumn = {
new TimestampColumn(ColumnPath.fromDotString(columnPath))
}
final class TimestampColumn(columnPath: ColumnPath)
extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt
def decimalColumn(columnPath: String): DecimalColumn = {
new DecimalColumn(ColumnPath.fromDotString(columnPath))
}
final class DecimalColumn(columnPath: ColumnPath)
extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt
final class WrappedDate(val date: Date) extends Comparable[WrappedDate] {
override def compareTo(other: WrappedDate): Int = {
date.compareTo(other.date)
}
}
final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] {
override def compareTo(other: WrappedTimestamp): Int = {
timestamp.compareTo(other.timestamp)
}
}
}


@ -23,8 +23,6 @@ import java.text.SimpleDateFormat
import java.util.concurrent.{Callable, TimeUnit}
import java.util.{ArrayList, Collections, Date, List => JList}
import org.apache.spark.annotation.DeveloperApi
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.Try
@ -34,22 +32,20 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path}
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat}
import parquet.hadoop._
import parquet.hadoop.api.ReadSupport.ReadContext
import parquet.hadoop.api.{InitContext, ReadSupport}
import parquet.hadoop.metadata.GlobalMetaData
import parquet.hadoop.api.ReadSupport.ReadContext
import parquet.hadoop.util.ContextUtil
import parquet.io.ParquetDecodingException
import parquet.schema.MessageType
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.{Logging, SerializableWritable, TaskContext}
@ -82,8 +78,6 @@ case class ParquetTableScan(
override def execute(): RDD[Row] = {
import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
import parquet.filter2.compat.FilterCompat.Filter
import parquet.filter2.predicate.FilterPredicate
val sc = sqlContext.sparkContext
val job = new Job(sc.hadoopConfiguration)
@ -111,14 +105,11 @@ case class ParquetTableScan(
// Note 1: the input format ignores all predicates that cannot be expressed
// as simple column predicate filters in Parquet. Here we just record
// the whole pruning predicate.
if (columnPruningPred.length > 0) {
ParquetFilters
.createRecordFilter(columnPruningPred)
.map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate)
// Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering
val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred)
if (filter != null){
val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate
ParquetInputFormat.setFilterPredicate(conf, filterPredicate)
}
}
.foreach(ParquetInputFormat.setFilterPredicate(conf, _))
// Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
conf.set(
@ -317,7 +308,7 @@ case class InsertIntoParquetTable(
}
writer.close(hadoopContext)
committer.commitTask(hadoopContext)
return 1
1
}
val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
/* apparently we need a TaskAttemptID to construct an OutputCommitter;
@ -375,9 +366,8 @@ private[parquet] class FilteringParquetRowInputFormat
override def createRecordReader(
inputSplit: InputSplit,
taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
import parquet.filter2.compat.FilterCompat.NoOpFilter
import parquet.filter2.compat.FilterCompat.Filter
val readSupport: ReadSupport[Row] = new RowReadSupport()
@ -392,7 +382,7 @@ private[parquet] class FilteringParquetRowInputFormat
}
override def getFooters(jobContext: JobContext): JList[Footer] = {
import FilteringParquetRowInputFormat.footerCache
import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache
if (footers eq null) {
val conf = ContextUtil.getConfiguration(jobContext)
@ -442,13 +432,13 @@ private[parquet] class FilteringParquetRowInputFormat
val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue)
val minSplitSize: JLong =
Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L))
Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L))
if (maxSplitSize < 0 || minSplitSize < 0) {
throw new ParquetDecodingException(
s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" +
s" minSplitSize = $minSplitSize")
}
// Uses strict type checking by default
val getGlobalMetaData =
classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]])
@ -458,29 +448,29 @@ private[parquet] class FilteringParquetRowInputFormat
if (globalMetaData == null) {
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
return splits
}
}
val readContext = getReadSupport(configuration).init(
new InitContext(configuration,
globalMetaData.getKeyValueMetaData(),
globalMetaData.getSchema()))
globalMetaData.getKeyValueMetaData,
globalMetaData.getSchema))
if (taskSideMetaData){
logInfo("Using Task Side Metadata Split Strategy")
return getTaskSideSplits(configuration,
getTaskSideSplits(configuration,
footers,
maxSplitSize,
minSplitSize,
readContext)
} else {
logInfo("Using Client Side Metadata Split Strategy")
return getClientSideSplits(configuration,
getClientSideSplits(configuration,
footers,
maxSplitSize,
minSplitSize,
readContext)
}
}
def getClientSideSplits(
@ -489,12 +479,11 @@ private[parquet] class FilteringParquetRowInputFormat
maxSplitSize: JLong,
minSplitSize: JLong,
readContext: ReadContext): JList[ParquetInputSplit] = {
import FilteringParquetRowInputFormat.blockLocationCache
import parquet.filter2.compat.FilterCompat;
import parquet.filter2.compat.FilterCompat.Filter;
import parquet.filter2.compat.RowGroupFilter;
import parquet.filter2.compat.FilterCompat.Filter
import parquet.filter2.compat.RowGroupFilter
import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache
val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true)
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
@ -503,7 +492,7 @@ private[parquet] class FilteringParquetRowInputFormat
var totalRowGroups: Long = 0
// Ugly hack, stuck with it until PR:
// https://github.com/apache/incubator-parquet-mr/pull/17
// https://github.com/apache/incubator-parquet-mr/pull/17
// is resolved
val generateSplits =
Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy")
@ -523,7 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat
blocks,
parquetMetaData.getFileMetaData.getSchema)
rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size)
if (!filteredBlocks.isEmpty){
var blockLocations: Array[BlockLocation] = null
if (!cacheMetadata) {
@ -566,7 +555,7 @@ private[parquet] class FilteringParquetRowInputFormat
readContext: ReadContext): JList[ParquetInputSplit] = {
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
// Ugly hack, stuck with it until PR:
// https://github.com/apache/incubator-parquet-mr/pull/17
// is resolved
@ -576,7 +565,7 @@ private[parquet] class FilteringParquetRowInputFormat
sys.error(
s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits"))
generateSplits.setAccessible(true)
for (footer <- footers) {
val file = footer.getFile
val fs = file.getFileSystem(configuration)
@ -594,7 +583,7 @@ private[parquet] class FilteringParquetRowInputFormat
}
splits
}
}
}
@ -636,11 +625,9 @@ private[parquet] object FileSystemHelper {
files.map(_.getName).map {
case nameP(taskid) => taskid.toInt
case hiddenFileP() => 0
case other: String => {
case other: String =>
sys.error("ERROR: attempting to append to set of Parquet files and found file" +
s"that does not match name pattern: $other")
0
}
case _ => 0
}.reduceLeft((a, b) => if (a < b) b else a)
}


@ -17,11 +17,13 @@
package org.apache.spark.sql.parquet
import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
import parquet.hadoop.ParquetFileWriter
import parquet.hadoop.util.ContextUtil
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType
@ -447,44 +449,24 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(true)
}
test("create RecordFilter for simple predicates") {
val attribute1 = new AttributeReference("first", IntegerType, false)()
val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType))
val filter1 = ParquetFilters.createFilter(predicate1)
assert(filter1.isDefined)
assert(filter1.get.predicate == predicate1, "predicates do not match")
assert(filter1.get.isInstanceOf[ComparisonFilter])
val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
assert(cmpFilter1.columnName == "first", "column name incorrect")
test("make RecordFilter for simple predicates") {
def checkFilter[T <: FilterPredicate](predicate: Expression, defined: Boolean = true): Unit = {
val filter = ParquetFilters.createFilter(predicate)
if (defined) {
assert(filter.isDefined)
assert(filter.get.isInstanceOf[T])
} else {
assert(filter.isEmpty)
}
}
val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
val filter2 = ParquetFilters.createFilter(predicate2)
assert(filter2.isDefined)
assert(filter2.get.predicate == predicate2, "predicates do not match")
assert(filter2.get.isInstanceOf[ComparisonFilter])
val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
assert(cmpFilter2.columnName == "first", "column name incorrect")
checkFilter[Operators.Eq[Integer]]('a.int === 1)
checkFilter[Operators.Lt[Integer]]('a.int < 4)
checkFilter[Operators.And]('a.int === 1 && 'a.int < 4)
checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4)
val predicate3 = new And(predicate1, predicate2)
val filter3 = ParquetFilters.createFilter(predicate3)
assert(filter3.isDefined)
assert(filter3.get.predicate == predicate3, "predicates do not match")
assert(filter3.get.isInstanceOf[AndFilter])
val predicate4 = new Or(predicate1, predicate2)
val filter4 = ParquetFilters.createFilter(predicate4)
assert(filter4.isDefined)
assert(filter4.get.predicate == predicate4, "predicates do not match")
assert(filter4.get.isInstanceOf[OrFilter])
val attribute2 = new AttributeReference("second", IntegerType, false)()
val predicate5 = new GreaterThan(attribute1, attribute2)
val badfilter = ParquetFilters.createFilter(predicate5)
assert(badfilter.isDefined === false)
val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2))
val badfilter2 = ParquetFilters.createFilter(predicate6)
assert(badfilter2.isDefined === false)
checkFilter('a.int > 'b.int, defined = false)
checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false)
}
test("test filter by predicate pushdown") {