From 4f309cec07b1b1e5f7cddb0f98598fb1a234c2bd Mon Sep 17 00:00:00 2001
From: "sandeep.katta"
Date: Wed, 21 Apr 2021 15:16:17 +0800
Subject: [PATCH] [SPARK-35096][SQL] SchemaPruning should adhere to the spark.sql.caseSensitive config

### What changes were proposed in this pull request?

As part of SPARK-26837, pruning of nested fields from object serializers is supported, but handling Spark's case-insensitive resolution of column names was missed.

In this PR, the column names to be pruned are resolved according to the `spark.sql.caseSensitive` config.
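For context, `conf.resolver` is the name-equality function that `SQLConf` selects based on `spark.sql.caseSensitive`. A minimal sketch of the behavior the fix relies on (illustration only, not part of the patch):

```scala
import org.apache.spark.sql.internal.SQLConf

// SQLConf.resolver returns a (String, String) => Boolean that compares
// field names either exactly or ignoring case, depending on the
// spark.sql.caseSensitive setting.
val conf = new SQLConf
conf.setConfString("spark.sql.caseSensitive", "false")
assert(conf.resolver("ID", "id"))   // case-insensitive: names match

conf.setConfString("spark.sql.caseSensitive", "true")
assert(!conf.resolver("ID", "id"))  // case-sensitive: names differ
```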
**Exception Before Fix**

```
Caused by: java.lang.ArrayIndexOutOfBoundsException: 0
  at org.apache.spark.sql.types.StructType.apply(StructType.scala:414)
  at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.$anonfun$applyOrElse$3(objects.scala:216)
  at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
  at scala.collection.immutable.List.foreach(List.scala:392)
  at scala.collection.TraversableLike.map(TraversableLike.scala:238)
  at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
  at scala.collection.immutable.List.map(List.scala:298)
  at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:215)
  at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:203)
  at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:309)
  at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
  at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:309)
  at ...
```

### Why are the changes needed?

After upgrading to Spark 3, the `foreachBatch` API throws `java.lang.ArrayIndexOutOfBoundsException`. This PR fixes that issue.

### Does this PR introduce _any_ user-facing change?

No. In fact, it fixes a regression.

### How was this patch tested?

Added tests and also verified manually.

Closes #32194 from sandeep-katta/SPARK-35096.

Authored-by: sandeep.katta
Signed-off-by: Wenchen Fan
---
 .../catalyst/expressions/SchemaPruning.scala  | 15 ++++---
 .../expressions/SchemaPruningSuite.scala      | 43 ++++++++++++++++++-
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 6213267c41..4ee6488c92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.types._
 
-object SchemaPruning {
+object SchemaPruning extends SQLConfHelper {
   /**
    * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
    * and given requested field are "a", the field "b" is pruned in the returned schema.
@@ -28,6 +29,7 @@ object SchemaPruning {
   def pruneDataSchema(
       dataSchema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
+    val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
@@ -36,7 +38,7 @@ object SchemaPruning {
       .reduceLeft(_ merge _)
     val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
@@ -61,12 +63,15 @@ object SchemaPruning {
           sortLeftFieldsByRight(leftValueType, rightValueType),
           containsNull)
       case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val resolver = conf.resolver
+        val filteredRightFieldNames = rightStruct.fieldNames
+          .filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
         val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
+          val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get
+          val leftFieldType = resolvedLeftStruct.dataType
           val rightFieldType = rightStruct(fieldName).dataType
           val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+          StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
         }
         StructType(sortedLeftFields)
       case _ => left
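To see the fix end to end, here is a sketch mirroring the new test below, assuming the default `spark.sql.caseSensitive=false`: a requested field that differs from the data schema only by letter case is now retained, keeping the data schema's casing.

```scala
import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.types._

// Requested field "id" vs. data schema field "ID": before this patch the
// name comparison was always case-sensitive, so "id" matched nothing, the
// merged schema came out empty, and downstream consumers such as
// ObjectSerializerPruning failed with ArrayIndexOutOfBoundsException.
val pruned = SchemaPruning.pruneDataSchema(
  StructType.fromDDL("ID int, name String"),
  Seq(SchemaPruning.RootField(
    field = StructField("id", IntegerType), derivedFromAtt = true)))
// With the fix: pruned == StructType(StructField("ID", IntegerType) :: Nil)
```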
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index c04f59ebb1..7895f4d5ef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -18,9 +18,20 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
-class SchemaPruningSuite extends SparkFunSuite {
+class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
+
+  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
+    requestedFields.map { f =>
+      // `derivedFromAtt` doesn't affect the result of the pruned schema.
+      SchemaPruning.RootField(field = f, derivedFromAtt = true)
+    }
+  }
+
   test("prune schema by the requested fields") {
     def testPrunedSchema(
         schema: StructType,
@@ -59,4 +70,34 @@ class SchemaPruningSuite extends SparkFunSuite {
         StructType.fromDDL("e int, f string")))
     testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
   }
+
+  test("SPARK-35096: test case insensitivity of pruned schema") {
+    Seq(true, false).foreach(isCaseSensitive => {
+      withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
+        if (isCaseSensitive) {
+          // Schema is case-sensitive
+          val requestedFields = getRootFields(StructField("id", IntegerType))
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"), requestedFields)
+          assert(prunedSchema == StructType(Seq.empty))
+          // Root fields are case-sensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(Seq.empty))
+        } else {
+          // Schema is case-insensitive
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"),
+            getRootFields(StructField("id", IntegerType)))
+          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
+          // Root fields are case-insensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+        }
+      }
+    })
+  }
 }
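A closing, hypothetical illustration (derived from the patched `sortLeftFieldsByRight`, not taken from the test suite): pruning recurses with the same resolver, so nested fields are matched case-insensitively too, and the data schema's casing is preserved at every level.

```scala
import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.types._

// Request "a.b" (lower case) against a data schema declaring
// "A struct<B: int, C: string>". With spark.sql.caseSensitive=false,
// "C" is pruned and the result keeps the data schema's casing.
val nested = SchemaPruning.pruneDataSchema(
  StructType.fromDDL("A struct<B: int, C: string>"),
  Seq(SchemaPruning.RootField(
    field = StructField("a", StructType.fromDDL("b int")),
    derivedFromAtt = true)))
// Expected (sketch): nested == StructType.fromDDL("A struct<B: int>")
```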