diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index 69badb4f7d..4dff1ec7eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -153,6 +153,11 @@ class OrcFileFormat filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(dataSchema, filters).foreach { f => + OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames) + } + } val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields) val sqlConf = sparkSession.sessionState.conf @@ -164,8 +169,6 @@ class OrcFileFormat val broadcastedConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown - val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles (file: PartitionedFile) => { val conf = broadcastedConf.value.value @@ -183,15 +186,6 @@ class OrcFileFormat if (resultedColPruneInfo.isEmpty) { Iterator.empty } else { - // ORC predicate pushdown - if (orcFilterPushDown) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema => - OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } - } - val (requestedColIds, canPruneCols) = resultedColPruneInfo.get val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols, dataSchema, resultSchema, partitionSchema, conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index ee0c08dd93..b277b4da1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -39,8 +39,6 @@ trait OrcFiltersBase { } } - case class OrcPrimitiveField(fieldName: String, fieldType: DataType) - /** * This method returns a map which contains ORC field name and data type. Each key * represents a column; `dots` are used as separators for nested columns. If any part @@ -51,21 +49,19 @@ trait OrcFiltersBase { */ protected[sql] def getSearchableTypeMap( schema: StructType, - caseSensitive: Boolean): Map[String, OrcPrimitiveField] = { + caseSensitive: Boolean): Map[String, DataType] = { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper def getPrimitiveFields( fields: Seq[StructField], - parentFieldNames: Seq[String] = Seq.empty): Seq[(String, OrcPrimitiveField)] = { + parentFieldNames: Seq[String] = Seq.empty): Seq[(String, DataType)] = { fields.flatMap { f => f.dataType match { case st: StructType => getPrimitiveFields(st.fields, parentFieldNames :+ f.name) case BinaryType => None case _: AtomicType => - val fieldName = (parentFieldNames :+ f.name).quoted - val orcField = OrcPrimitiveField(fieldName, f.dataType) - Some((fieldName, orcField)) + Some(((parentFieldNames :+ f.name).quoted, f.dataType)) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 264cf8165e..072e670081 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -92,20 +92,6 @@ object OrcUtils extends Logging { } } - def readCatalystSchema( - file: Path, - conf: Configuration, - ignoreCorruptFiles: Boolean): Option[StructType] = { - readSchema(file, conf, ignoreCorruptFiles) match { - case Some(schema) => - Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType]) - - case None => - // Field names is empty or `FileFormatException` was thrown but ignoreCorruptFiles is true. - None - } - } - /** * Reads ORC file schemas in multi-threaded manner, using native version of ORC. * This is visible for testing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 1f38128e98..7f25f7bd13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -31,10 +31,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} +import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.{AtomicType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -53,13 +52,10 @@ case class OrcPartitionReaderFactory( broadcastedConf: Broadcast[SerializableConfiguration], dataSchema: StructType, readDataSchema: StructType, - partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + partitionSchema: StructType) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize - private val orcFilterPushDown = sqlConf.orcFilterPushDown - private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled && @@ -67,16 +63,6 @@ case class OrcPartitionReaderFactory( resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) } - private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = { - if (orcFilterPushDown) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema => - OrcFilters.createFilter(fileSchema, filters).foreach { f => - OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) - } - } - } - } - override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = broadcastedConf.value.value @@ -84,8 +70,6 @@ case class OrcPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) - val fs = filePath.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = @@ -132,8 +116,6 @@ case class OrcPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) - val fs = filePath.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 1710abed57..38b8ced51a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -48,7 +48,7 @@ case class OrcScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema) } override def equals(obj: Any): Boolean = obj match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 2f9387532c..0330dacffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -56,6 +56,11 @@ case class OrcScanBuilder( override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { + OrcFilters.createFilter(schema, filters).foreach { f => + // The pushed filters will be set in `hadoopConf`. After that, we can simply use the + // changed `hadoopConf` in executors. + OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) + } val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray } diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 359e075bff..bc11bb8c1d 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], filters: Seq[Filter]): Seq[Filter] = { import org.apache.spark.sql.sources._ @@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildSearchArgument( - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Builder = { import org.apache.spark.sql.sources._ @@ -215,7 +215,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildLeafSearchArgument( - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { def getType(attribute: String): PredicateLeaf.Type = @@ -228,44 +228,38 @@ private[sql] object OrcFilters extends OrcFiltersBase { // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). expression match { case EqualTo(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) case EqualNullSafe(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) case LessThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) case LessThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) case GreaterThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) case IsNull(name) if dataTypeMap.contains(name) => - Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end()) + Some(builder.startAnd().isNull(name, getType(name)).end()) case IsNotNull(name) if dataTypeMap.contains(name) => - Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end()) + Some(builder.startNot().isNull(name, getType(name)).end()) case In(name, values) if dataTypeMap.contains(name) => - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType)) - Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name), + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index e159a0588d..dfb3595be9 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -24,7 +24,6 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} -import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row} @@ -587,7 +586,8 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1))) val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0")) - assert(actual.count() == 1) + // TODO: ORC predicate pushdown should work under case-insensitive analysis. + // assert(actual.count() == 1) } } @@ -606,71 +606,5 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } } } - - test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") { - import org.apache.spark.sql.sources._ - - def getOrcFilter( - schema: StructType, - filters: Seq[Filter], - caseSensitive: String): Option[SearchArgument] = { - var orcFilter: Option[SearchArgument] = None - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { - orcFilter = - OrcFilters.createFilter(schema, filters) - } - orcFilter - } - - def testFilter( - schema: StructType, - filters: Seq[Filter], - expected: SearchArgument): Unit = { - val caseSensitiveFilters = getOrcFilter(schema, filters, "true") - val caseInsensitiveFilters = getOrcFilter(schema, filters, "false") - - assert(caseSensitiveFilters.isEmpty) - assert(caseInsensitiveFilters.isDefined) - - assert(caseInsensitiveFilters.get.getLeaves().size() > 0) - assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size()) - (0 until expected.getLeaves().size()).foreach { index => - assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index)) - } - } - - val schema1 = StructType(Seq(StructField("cint", IntegerType))) - testFilter(schema1, Seq(GreaterThan("CINT", 1)), - newBuilder.startNot() - .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema1, Seq( - And(GreaterThan("CINT", 1), EqualTo("Cint", 2))), - newBuilder.startAnd() - .startNot() - .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`() - .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L) - .`end`().build()) - - // Nested column case - val schema2 = StructType(Seq(StructField("a", - StructType(Seq(StructField("cint", IntegerType)))))) - - testFilter(schema2, Seq(GreaterThan("A.CINT", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq(GreaterThan("a.CINT", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq(GreaterThan("A.cint", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq( - And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))), - newBuilder.startAnd() - .startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`() - .equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L) - .`end`().build()) - } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 9acb04c8ca..5273245fae 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -81,7 +81,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], filters: Seq[Filter]): Seq[Filter] = { import org.apache.spark.sql.sources._ @@ -139,7 +139,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Get PredicateLeafType which is corresponding to the given DataType. */ - private[sql] def getPredicateLeafType(dataType: DataType) = dataType match { + private def getPredicateLeafType(dataType: DataType) = dataType match { case BooleanType => PredicateLeaf.Type.BOOLEAN case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG case FloatType | DoubleType => PredicateLeaf.Type.FLOAT @@ -179,7 +179,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildSearchArgument( - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Builder = { import org.apache.spark.sql.sources._ @@ -215,11 +215,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { * @return the builder so far. */ private def buildLeafSearchArgument( - dataTypeMap: Map[String, OrcPrimitiveField], + dataTypeMap: Map[String, DataType], expression: Filter, builder: Builder): Option[Builder] = { def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute).fieldType) + getPredicateLeafType(dataTypeMap(attribute)) import org.apache.spark.sql.sources._ @@ -228,46 +228,38 @@ private[sql] object OrcFilters extends OrcFiltersBase { // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). expression match { case EqualTo(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().equals(name, getType(name), castedValue).end()) case EqualNullSafe(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end()) case LessThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThan(name, getType(name), castedValue).end()) case LessThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end()) case GreaterThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end()) case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + val castedValue = castLiteralValue(value, dataTypeMap(name)) + Some(builder.startNot().lessThan(name, getType(name), castedValue).end()) case IsNull(name) if dataTypeMap.contains(name) => - Some(builder.startAnd() - .isNull(dataTypeMap(name).fieldName, getType(name)).end()) + Some(builder.startAnd().isNull(name, getType(name)).end()) case IsNotNull(name) if dataTypeMap.contains(name) => - Some(builder.startNot() - .isNull(dataTypeMap(name).fieldName, getType(name)).end()) + Some(builder.startNot().isNull(name, getType(name)).end()) case In(name, values) if dataTypeMap.contains(name) => - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType)) - Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name), + val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name))) + Some(builder.startAnd().in(name, getType(name), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index afc83d7c39..84cd2777da 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -24,7 +24,6 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} -import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row} @@ -588,7 +587,8 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1))) val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0")) - assert(actual.count() == 1) + // TODO: ORC predicate pushdown should work under case-insensitive analysis. + // assert(actual.count() == 1) } } @@ -607,71 +607,5 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } } } - - test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") { - import org.apache.spark.sql.sources._ - - def getOrcFilter( - schema: StructType, - filters: Seq[Filter], - caseSensitive: String): Option[SearchArgument] = { - var orcFilter: Option[SearchArgument] = None - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { - orcFilter = - OrcFilters.createFilter(schema, filters) - } - orcFilter - } - - def testFilter( - schema: StructType, - filters: Seq[Filter], - expected: SearchArgument): Unit = { - val caseSensitiveFilters = getOrcFilter(schema, filters, "true") - val caseInsensitiveFilters = getOrcFilter(schema, filters, "false") - - assert(caseSensitiveFilters.isEmpty) - assert(caseInsensitiveFilters.isDefined) - - assert(caseInsensitiveFilters.get.getLeaves().size() > 0) - assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size()) - (0 until expected.getLeaves().size()).foreach { index => - assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index)) - } - } - - val schema1 = StructType(Seq(StructField("cint", IntegerType))) - testFilter(schema1, Seq(GreaterThan("CINT", 1)), - newBuilder.startNot() - .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema1, Seq( - And(GreaterThan("CINT", 1), EqualTo("Cint", 2))), - newBuilder.startAnd() - .startNot() - .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`() - .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L) - .`end`().build()) - - // Nested column case - val schema2 = StructType(Seq(StructField("a", - StructType(Seq(StructField("cint", IntegerType)))))) - - testFilter(schema2, Seq(GreaterThan("A.CINT", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq(GreaterThan("a.CINT", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq(GreaterThan("A.cint", 1)), - newBuilder.startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build()) - testFilter(schema2, Seq( - And(GreaterThan("a.CINT", 1), EqualTo("a.Cint", 2))), - newBuilder.startAnd() - .startNot() - .lessThanEquals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`() - .equals("a.cint", OrcFilters.getPredicateLeafType(IntegerType), 2L) - .`end`().build()) - } }