diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a69f09bdce..196971a22a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2108,9 +2108,9 @@ object SQLConf { .doc("A comma-separated list of data source short names or fully qualified data source " + "implementation class names for which Spark tries to push down predicates for nested " + "columns and/or names containing `dots` to data sources. This configuration is only " + - "effective with file-based data source in DSv1. Currently, Parquet implements " + - "both optimizations while ORC only supports predicates for names containing `dots`. The " + - "other data sources don't support this feature yet. So the default value is 'parquet,orc'.") + "effective with file-based data sources in DSv1. Currently, Parquet and ORC implement " + + "both optimizations. The other data sources don't support this feature yet. So the " + + "default value is 'parquet,orc'.") .version("3.0.0") .stringConf .createWithDefault("parquet,orc") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 23454d7d5e..ada04c2382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -668,6 +668,8 @@ abstract class PushableColumnBase { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper def helper(e: Expression): Option[Seq[String]] = e match { case a: Attribute => + // Attribute that contains dot "." in name is supported only when + // nested predicate pushdown is enabled. if (nestedPredicatePushdownEnabled || !a.name.contains(".")) { Some(Seq(a.name)) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index e673309188..b277b4da1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.datasources.orc +import java.util.Locale + +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.sources.{And, Filter} -import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType} +import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType} /** * Methods that can be shared when upgrading the built-in Hive. @@ -37,12 +40,45 @@ trait OrcFiltersBase { } /** - * Return true if this is a searchable type in ORC. - * Both CharType and VarcharType are cleaned at AstBuilder. + * This method returns a map which contains ORC field name and data type. Each key + * represents a column; `dots` are used as separators for nested columns. If any part + * of the names contains `dots`, it is quoted to avoid confusion. See + * `org.apache.spark.sql.connector.catalog.quoted` for implementation details. + * + * BinaryType, UserDefinedType, ArrayType and MapType are ignored. 
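+ * + * For example, given a schema with fields `a INT`, `b STRUCT<c: INT>` and a + * top-level field named `d.e` of string type, the returned map in case-sensitive + * mode is Map("a" -> IntegerType, "b.c" -> IntegerType, "`d.e`" -> StringType).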
*/ - protected[sql] def isSearchableType(dataType: DataType) = dataType match { - case BinaryType => false - case _: AtomicType => true - case _ => false + protected[sql] def getSearchableTypeMap( + schema: StructType, + caseSensitive: Boolean): Map[String, DataType] = { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + + def getPrimitiveFields( + fields: Seq[StructField], + parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = { + fields.flatMap { f => + f.dataType match { + case st: StructType => + getPrimitiveFields(st.fields, parentFieldNames :+ f.name) + case BinaryType => None + case _: AtomicType => + Some(((parentFieldNames :+ f.name).quoted, f.dataType)) + case _ => None + } + } + } + + val primitiveFields = getPrimitiveFields(schema.fields) + if (caseSensitive) { + primitiveFields.toMap + } else { + // Don't consider ambiguity here, i.e. if more than one field is matched in + // case-insensitive mode, just skip pushdown for these fields; they will trigger + // an exception when read. See SPARK-25175. + val dedupPrimitiveFields = primitiveFields + .groupBy(_._1.toLowerCase(Locale.ROOT)) + .filter(_._2.size == 1) + .mapValues(_.head._2) + CaseInsensitiveMap(dedupPrimitiveFields) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 9f40f5faa2..0330dacffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -60,10 +61,8 @@ case class OrcScanBuilder( // changed `hadoopConf` in executors. OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap - // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. 
- val newFilters = filters.filter(!_.containsNestedColumn) - _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray + val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) + _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray } filters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala index bdb161d59a..c2dc20b009 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala @@ -22,8 +22,10 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.{DataFrame, SaveMode} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} +import org.apache.spark.sql.functions.struct import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.StructType /** * A helper trait that provides convenient facilities for file-based data source testing. @@ -103,4 +105,40 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils { df: DataFrame, path: File): Unit = { df.write.mode(SaveMode.Overwrite).format(dataSourceName).save(path.getCanonicalPath) } + + /** + * Takes a single-level `inputDF` dataframe to generate multi-level nested + * dataframes as new test data. It tests both non-nested and nested dataframes, + * which are written and read back with the specified datasource. + */ + protected def withNestedDataFrame(inputDF: DataFrame): Seq[(DataFrame, String, Any => Any)] = { + assert(inputDF.schema.fields.length == 1) + assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType]) + val df = inputDF.toDF("temp") + Seq( + ( + df.withColumnRenamed("temp", "a"), + "a", // zero nesting + (x: Any) => x), + ( + df.withColumn("a", struct(df("temp") as "b")).drop("temp"), + "a.b", // one level nesting + (x: Any) => Row(x)), + ( + df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"), + "a.b.c", // two level nesting + (x: Any) => Row(Row(x)) + ), + ( + df.withColumnRenamed("temp", "a.b"), + "`a.b`", // zero nesting with column name containing `dots` + (x: Any) => x + ), + ( + df.withColumn("a.b", struct(df("temp") as "c.d")).drop("temp"), + "`a.b`.`c.d`", // one level nesting with column names containing `dots` + (x: Any) => Row(x) + ) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index e929f904c7..aec61acda5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -143,4 +143,26 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor FileUtils.copyURLToFile(url, file) spark.read.orc(file.getAbsolutePath) } + + /** + * Takes a sequence of products `data` to generate multi-level nested + * dataframes as new test data. It tests both non-nested and nested dataframes, + * which are written and read back with the Orc datasource. + * + * This is different from [[withOrcDataFrame]], which does not + * test nested cases. 
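+ * + * Each `runTest` invocation receives the generated dataframe, the column name to + * filter on (e.g. "a", "a.b.c" or "`a.b`.`c.d`"), and a function that wraps an + * expected value in the matching nesting of rows.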
+ */ + protected def withNestedOrcDataFrame[T <: Product: ClassTag: TypeTag](data: Seq[T]) + (runTest: (DataFrame, String, Any => Any) => Unit): Unit = + withNestedOrcDataFrame(spark.createDataFrame(data))(runTest) + + protected def withNestedOrcDataFrame(inputDF: DataFrame) + (runTest: (DataFrame, String, Any => Any) => Unit): Unit = { + withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) => + withTempPath { file => + newDF.write.format(dataSourceName).save(file.getCanonicalPath) + readFile(file.getCanonicalPath, true) { df => runTest(df, colName, resultFun) } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 8b922aaed4..5689b9d05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -122,34 +122,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared private def withNestedParquetDataFrame(inputDF: DataFrame) (runTest: (DataFrame, String, Any => Any) => Unit): Unit = { - assert(inputDF.schema.fields.length == 1) - assert(!inputDF.schema.fields.head.dataType.isInstanceOf[StructType]) - val df = inputDF.toDF("temp") - Seq( - ( - df.withColumnRenamed("temp", "a"), - "a", // zero nesting - (x: Any) => x), - ( - df.withColumn("a", struct(df("temp") as "b")).drop("temp"), - "a.b", // one level nesting - (x: Any) => Row(x)), - ( - df.withColumn("a", struct(struct(df("temp") as "c") as "b")).drop("temp"), - "a.b.c", // two level nesting - (x: Any) => Row(Row(x)) - ), - ( - df.withColumnRenamed("temp", "a.b"), - "`a.b`", // zero nesting with column name containing `dots` - (x: Any) => x - ), - ( - df.withColumn("a.b", struct(df("temp") as "c.d") ).drop("temp"), - "`a.b`.`c.d`", // one level nesting with column names containing `dots` - (x: Any) => Row(x) - ) - ).foreach { case (newDF, colName, resultFun) => + withNestedDataFrame(inputDF).foreach { case (newDF, colName, resultFun) => withTempPath { file => newDF.write.format(dataSourceName).save(file.getCanonicalPath) readParquetFile(file.getCanonicalPath) { df => runTest(df, colName, resultFun) } diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index b68563956c..bc11bb8c1d 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -68,11 +68,9 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. 
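+ * + * With nested predicate pushdown, filters on nested fields (e.g. `EqualTo("a.b", 1)`) + * are now convertible as long as the field name resolves to a primitive entry in the + * map produced by `getSearchableTypeMap`.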
*/ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) // Combines all convertible filters using `And` to produce a single conjunction - // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. - val newFilters = filters.filter(!_.containsNestedColumn) - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the @@ -228,40 +226,38 @@ private[sql] object OrcFilters extends OrcFiltersBase { // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). - // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters - // in order to distinguish predicate pushdown for nested columns. expression match { - case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + case EqualTo(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + case EqualNullSafe(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + case LessThan(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + case LessThanOrEqual(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + case GreaterThan(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(name) if isSearchableType(dataTypeMap(name)) => + case IsNull(name) if dataTypeMap.contains(name) => Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + case IsNotNull(name) if dataTypeMap.contains(name) => Some(builder.startNot().isNull(name, getType(name)).end()) - case In(name, values) if isSearchableType(dataTypeMap(name)) => + case In(name, values) if dataTypeMap.contains(name) => val castedValues 
= values.map(v => castLiteralValue(v, dataTypeMap(name))) Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 88b4b243b5..2643196cac 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -92,155 +92,199 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - integer") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val intAttr = df(colName).expr + assert(df(colName).expr.dataType === IntegerType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(intAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(intAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(intAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(intAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === intAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> intAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > intAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < intAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= intAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= intAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - long") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toLong)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val longAttr = df(colName).expr + 
assert(df(colName).expr.dataType === LongType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(longAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(longAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(longAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(longAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === longAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> longAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > longAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < longAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= longAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= longAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - float") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toFloat)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val floatAttr = df(colName).expr + assert(df(colName).expr.dataType === FloatType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(floatAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(floatAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr > 3, 
PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(floatAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(floatAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === floatAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> floatAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > floatAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < floatAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= floatAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= floatAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - double") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toDouble)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val doubleAttr = df(colName).expr + assert(df(colName).expr.dataType === DoubleType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(doubleAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(doubleAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(doubleAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(doubleAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === doubleAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> doubleAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > doubleAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < doubleAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= doubleAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= doubleAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - string") { - withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === "1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> "1", 
PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val strAttr = df(colName).expr + assert(df(colName).expr.dataType === StringType) - checkFilterPredicate($"_1" < "2", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= "4", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal("1") === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal("1") <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal("2") > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal("3") < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal("1") >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal("4") <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(strAttr <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(strAttr < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(strAttr <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(strAttr >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === strAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> strAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > strAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < strAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= strAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= strAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === true, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val booleanAttr = df(colName).expr + assert(df(colName).expr.dataType === BooleanType) - checkFilterPredicate($"_1" < true, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(false) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(false) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(false) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(true) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(booleanAttr 
<=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(booleanAttr < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(booleanAttr <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(booleanAttr >= false, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(false) === booleanAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> booleanAttr, + PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > booleanAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < booleanAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= booleanAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= booleanAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - decimal") { - withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val decimalAttr = df(colName).expr + assert(df(colName).expr.dataType === DecimalType(38, 18)) - checkFilterPredicate($"_1" < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(decimalAttr.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate(decimalAttr === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(decimalAttr <=> BigDecimal.valueOf(1), + PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(decimalAttr < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(decimalAttr > BigDecimal.valueOf(3), + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(decimalAttr <= BigDecimal.valueOf(1), + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(decimalAttr >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) === $"_1", PredicateLeaf.Operator.EQUALS) + Literal(BigDecimal.valueOf(1)) === decimalAttr, PredicateLeaf.Operator.EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + Literal(BigDecimal.valueOf(1)) <=> decimalAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(2)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + Literal(BigDecimal.valueOf(2)) > decimalAttr, PredicateLeaf.Operator.LESS_THAN) checkFilterPredicate( - Literal(BigDecimal.valueOf(3)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + Literal(BigDecimal.valueOf(3)) < decimalAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + Literal(BigDecimal.valueOf(1)) >= decimalAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) 
checkFilterPredicate( - Literal(BigDecimal.valueOf(4)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + Literal(BigDecimal.valueOf(4)) <= decimalAttr, PredicateLeaf.Operator.LESS_THAN) } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 4b642080d2..5273245fae 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ @@ -68,11 +68,9 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap + val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) // Combines all convertible filters using `And` to produce a single conjunction - // TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed. - val newFilters = filters.filter(!_.containsNestedColumn) - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, newFilters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the @@ -228,40 +226,38 @@ private[sql] object OrcFilters extends OrcFiltersBase { // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). - // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters - // in order to distinguish predicate pushdown for nested columns. 
expression match { - case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) => + case EqualTo(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().equals(name, getType(name), castedValue).end()) - case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) => + case EqualNullSafe(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) - case LessThan(name, value) if isSearchableType(dataTypeMap(name)) => + case LessThan(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) - case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + case LessThanOrEqual(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) => + case GreaterThan(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) - case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) => + case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) => val castedValue = castLiteralValue(value, dataTypeMap(name)) Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) - case IsNull(name) if isSearchableType(dataTypeMap(name)) => + case IsNull(name) if dataTypeMap.contains(name) => Some(builder.startAnd().isNull(name, getType(name)).end()) - case IsNotNull(name) if isSearchableType(dataTypeMap(name)) => + case IsNotNull(name) if dataTypeMap.contains(name) => Some(builder.startNot().isNull(name, getType(name)).end()) - case In(name, values) if isSearchableType(dataTypeMap(name)) => + case In(name, values) if dataTypeMap.contains(name) => val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index 2263179515..7df9f29b42 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.{AnalysisException, Column, DataFrame} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} -import org.apache.spark.sql.execution.datasources.v2.orc.{OrcScan, OrcTable} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -93,155 +92,200 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } test("filter pushdown - integer") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val intAttr = df(colName).expr + assert(df(colName).expr.dataType === IntegerType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(intAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(intAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(intAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(intAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(intAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === intAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> intAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > intAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < intAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= intAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= intAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - long") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toLong)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val longAttr = df(colName).expr + assert(df(colName).expr.dataType === LongType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", 
PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(longAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(longAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(longAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(longAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(longAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === longAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> longAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > longAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < longAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= longAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= longAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - float") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toFloat)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val floatAttr = df(colName).expr + assert(df(colName).expr.dataType === FloatType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(floatAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(floatAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(floatAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(floatAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(floatAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === floatAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> floatAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > floatAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < floatAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + 
checkFilterPredicate(Literal(1) >= floatAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= floatAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - double") { - withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(Option(i.toDouble)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === 1, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val doubleAttr = df(colName).expr + assert(df(colName).expr.dataType === DoubleType) - checkFilterPredicate($"_1" < 2, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= 4, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(1) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(1) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(2) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(3) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(1) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(4) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(doubleAttr <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(doubleAttr < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(doubleAttr > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(doubleAttr <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(doubleAttr >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === doubleAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> doubleAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > doubleAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < doubleAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= doubleAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= doubleAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - string") { - withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1(i.toString))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === "1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val strAttr = df(colName).expr + assert(df(colName).expr.dataType === StringType) - checkFilterPredicate($"_1" < "2", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= "4", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal("1") === $"_1", 
PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal("1") <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal("2") > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal("3") < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal("1") >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal("4") <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(strAttr <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(strAttr < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(strAttr > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(strAttr <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(strAttr >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === strAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> strAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal("2") > strAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal("3") < strAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("1") >= strAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal("4") <= strAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === true, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val booleanAttr = df(colName).expr + assert(df(colName).expr.dataType === BooleanType) - checkFilterPredicate($"_1" < true, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr.isNull, PredicateLeaf.Operator.IS_NULL) - checkFilterPredicate(Literal(false) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(false) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(false) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(true) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(booleanAttr <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(booleanAttr < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(booleanAttr > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(booleanAttr <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(booleanAttr >= false, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(false) === booleanAttr, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> booleanAttr, + 
PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > booleanAttr, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < booleanAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= booleanAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= booleanAttr, PredicateLeaf.Operator.LESS_THAN) } } test("filter pushdown - decimal") { - withOrcDataFrame((1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + withNestedOrcDataFrame( + (1 to 4).map(i => Tuple1.apply(BigDecimal.valueOf(i)))) { case (inputDF, colName, _) => + implicit val df: DataFrame = inputDF - checkFilterPredicate($"_1" === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> BigDecimal.valueOf(1), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val decimalAttr = df(colName).expr + assert(df(colName).expr.dataType === DecimalType(38, 18)) - checkFilterPredicate($"_1" < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > BigDecimal.valueOf(3), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= BigDecimal.valueOf(1), PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(decimalAttr.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate(decimalAttr === BigDecimal.valueOf(1), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(decimalAttr <=> BigDecimal.valueOf(1), + PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate(decimalAttr < BigDecimal.valueOf(2), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(decimalAttr > BigDecimal.valueOf(3), + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(decimalAttr <= BigDecimal.valueOf(1), + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(decimalAttr >= BigDecimal.valueOf(4), PredicateLeaf.Operator.LESS_THAN) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) === $"_1", PredicateLeaf.Operator.EQUALS) + Literal(BigDecimal.valueOf(1)) === decimalAttr, PredicateLeaf.Operator.EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + Literal(BigDecimal.valueOf(1)) <=> decimalAttr, PredicateLeaf.Operator.NULL_SAFE_EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(2)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + Literal(BigDecimal.valueOf(2)) > decimalAttr, PredicateLeaf.Operator.LESS_THAN) checkFilterPredicate( - Literal(BigDecimal.valueOf(3)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + Literal(BigDecimal.valueOf(3)) < decimalAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(1)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + Literal(BigDecimal.valueOf(1)) >= decimalAttr, PredicateLeaf.Operator.LESS_THAN_EQUALS) checkFilterPredicate( - Literal(BigDecimal.valueOf(4)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + Literal(BigDecimal.valueOf(4)) <= decimalAttr, PredicateLeaf.Operator.LESS_THAN) } }
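A minimal end-to-end sketch of what this patch enables, assuming a local SparkSession and a scratch output path (both illustrative, not part of the patch): after writing a dataframe with a nested column to ORC, a filter on the nested field should now surface under `PushedFilters` in the physical plan instead of being dropped.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct}

object NestedOrcPushdownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

    // Build a dataframe with a single nested column a.b, mirroring withNestedDataFrame.
    val input = spark.range(1, 5).toDF("temp")
      .withColumn("a", struct(col("temp") as "b"))
      .drop("temp")

    val path = "/tmp/nested_orc_sketch" // hypothetical scratch path
    input.write.mode("overwrite").orc(path)

    // "orc" is included in spark.sql.optimizer.nestedPredicatePushdown.supportedFileSources
    // by default, so the predicate on a.b is expected to reach the ORC reader.
    val filtered = spark.read.orc(path).where(col("a.b") > 2)
    filtered.explain(true) // PushedFilters should include a predicate on a.b
    filtered.show()

    spark.stop()
  }
}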