[SPARK-35096][SQL] SchemaPruning should adhere to the spark.sql.caseSensitive config
### What changes were proposed in this pull request?

As part of SPARK-26837, pruning of nested fields from object serializers is supported. However, that change missed handling Spark's case-insensitivity. In this PR, the column names to be pruned are resolved according to the `spark.sql.caseSensitive` config.

**Exception before the fix**

```
Caused by: java.lang.ArrayIndexOutOfBoundsException: 0
	at org.apache.spark.sql.types.StructType.apply(StructType.scala:414)
	at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.$anonfun$applyOrElse$3(objects.scala:216)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
	at scala.collection.immutable.List.foreach(List.scala:392)
	at scala.collection.TraversableLike.map(TraversableLike.scala:238)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
	at scala.collection.immutable.List.map(List.scala:298)
	at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:215)
	at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:203)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:309)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:309)
	at
```

### Why are the changes needed?

After upgrading to Spark 3, the `foreachBatch` API throws `java.lang.ArrayIndexOutOfBoundsException`. This PR fixes that issue.

### Does this PR introduce _any_ user-facing change?

No; in fact, it fixes a regression.

### How was this patch tested?

Added tests and also verified manually.

Closes #32194 from sandeep-katta/SPARK-35096.

Authored-by: sandeep.katta <sandeep.katta2007@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
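For context, here is a minimal standalone sketch of the matching change at the heart of this fix. It uses catalyst's built-in resolver functions directly rather than `conf.resolver`; the schema and field names are illustrative:

```scala
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.types.StructType

object ResolverSketch extends App {
  val dataSchema = StructType.fromDDL("ID int, name string")

  // Exact string matching (the old behaviour) misses "id" when the data
  // schema spells it "ID":
  println(dataSchema.fieldNames.contains("id"))                             // false

  // Resolver-based matching can honour case insensitivity:
  println(dataSchema.fieldNames.exists(caseInsensitiveResolution(_, "id"))) // true
  println(dataSchema.fieldNames.exists(caseSensitiveResolution(_, "id")))   // false
}
```

The PR obtains the resolver from `conf.resolver` (via `SQLConfHelper`), which selects between these two functions based on `spark.sql.caseSensitive`.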
This commit is contained in:

parent 97ec57e667
commit 4f309cec07
**SchemaPruning.scala**

```diff
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.types._
 
-object SchemaPruning {
+object SchemaPruning extends SQLConfHelper {
   /**
    * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
    * and given requested field are "a", the field "b" is pruned in the returned schema.
@@ -28,6 +29,7 @@ object SchemaPruning {
   def pruneDataSchema(
       dataSchema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
+    val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
@@ -36,7 +38,7 @@ object SchemaPruning {
       .reduceLeft(_ merge _)
     val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
@@ -61,12 +63,15 @@ object SchemaPruning {
           sortLeftFieldsByRight(leftValueType, rightValueType),
           containsNull)
       case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val resolver = conf.resolver
+        val filteredRightFieldNames = rightStruct.fieldNames
+          .filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
         val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
+          val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get
+          val leftFieldType = resolvedLeftStruct.dataType
           val rightFieldType = rightStruct(fieldName).dataType
           val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+          StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
         }
         StructType(sortedLeftFields)
       case _ => left
```
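For illustration, a rough sketch of the kind of `foreachBatch` pipeline that could hit the exception before this fix. The case classes, source schema, and paths are all hypothetical; the point is only that the object's field casing differs from the data schema's while `spark.sql.caseSensitive` is `false`:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical types whose field casing differs from the source columns.
case class Payload(msg: String)
case class Event(id: Int, payload: Payload)

object ForeachBatchRepro extends App {
  val spark = SparkSession.builder().appName("SPARK-35096-repro").getOrCreate()
  import spark.implicits._

  spark.readStream
    .schema("ID INT, PAYLOAD STRUCT<MSG: STRING>") // upper-case source columns
    .parquet("/tmp/events")                        // hypothetical input path
    .as[Event]
    .writeStream
    .foreachBatch { (batch: Dataset[Event], _: Long) =>
      // Using only part of the object lets the optimizer prune the object
      // serializer; before the fix, the "payload"/"PAYLOAD" case mismatch
      // could end in ArrayIndexOutOfBoundsException.
      batch.map(_.payload.msg).write.mode("append").text("/tmp/out") // hypothetical output
    }
    .start()
    .awaitTermination()
}
```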
**SchemaPruningSuite.scala**

```diff
@@ -18,9 +18,20 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
-class SchemaPruningSuite extends SparkFunSuite {
+class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
+
+  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
+    requestedFields.map { f =>
+      // `derivedFromAtt` doesn't affect the result of pruned schema.
+      SchemaPruning.RootField(field = f, derivedFromAtt = true)
+    }
+  }
+
   test("prune schema by the requested fields") {
     def testPrunedSchema(
         schema: StructType,
@@ -59,4 +70,34 @@ class SchemaPruningSuite extends SparkFunSuite {
       StructType.fromDDL("e int, f string")))
     testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
   }
+
+  test("SPARK-35096: test case insensitivity of pruned schema") {
+    Seq(true, false).foreach(isCaseSensitive => {
+      withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
+        if (isCaseSensitive) {
+          // Schema is case-sensitive
+          val requestedFields = getRootFields(StructField("id", IntegerType))
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"), requestedFields)
+          assert(prunedSchema == StructType(Seq.empty))
+          // Root fields are case-sensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructType(Seq.empty)))
+        } else {
+          // Schema is case-insensitive
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"),
+            getRootFields(StructField("id", IntegerType)))
+          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
+          // Root fields are case-insensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+        }
+      }
+    })
+  }
 }
```
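With the fix applied, the behaviour distilled from the new test looks like the sketch below, assuming the default `spark.sql.caseSensitive=false` is in effect via `SQLConf.get`:

```scala
import org.apache.spark.sql.catalyst.expressions.SchemaPruning
import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
import org.apache.spark.sql.types._

// Requesting lower-case "id" against an upper-case "ID" column now prunes the
// schema down to the matching field instead of producing an empty match that
// later fails with ArrayIndexOutOfBoundsException.
val pruned = SchemaPruning.pruneDataSchema(
  StructType.fromDDL("ID int, name String"),
  Seq(RootField(field = StructField("id", IntegerType), derivedFromAtt = true)))
// pruned == StructType(StructField("ID", IntegerType) :: Nil)
```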