[SPARK-34133][AVRO] Respect case sensitivity when performing Catalyst-to-Avro field matching
### What changes were proposed in this pull request? Make the field name matching between Avro and Catalyst schemas, on both the reader and writer paths, respect the global SQL settings for case sensitivity (i.e. case-insensitive by default). `AvroSerializer` and `AvroDeserializer` share a common utility in `AvroUtils` to search for an Avro field to match a given Catalyst field. ### Why are the changes needed? Spark SQL is normally case-insensitive (by default), but currently when `AvroSerializer` and `AvroDeserializer` perform matching between Catalyst schemas and Avro schemas, the matching is done in a case-sensitive manner. So for example the following will fail: ```scala val avroSchema = """ |{ | "type" : "record", | "name" : "test_schema", | "fields" : [ | {"name": "foo", "type": "int"}, | {"name": "BAR", "type": "int"} | ] |} """.stripMargin val df = Seq((1, 3), (2, 4)).toDF("FOO", "bar") df.write.option("avroSchema", avroSchema).format("avro").save(savePath) ``` The same is true on the read path, if we assume `testAvro` has been written using the schema above, the below will fail to match the fields: ```scala df.read.schema(new StructType().add("FOO", IntegerType).add("bar", IntegerType)) .format("avro").load(testAvro) ``` ### Does this PR introduce _any_ user-facing change? When reading Avro data, or writing Avro data using the `avroSchema` option, field matching will be performed with case sensitivity respecting the global SQL settings. ### How was this patch tested? New tests added to `AvroSuite` to validate the case sensitivity logic in an end-to-end manner through the SQL engine. Closes #31201 from xkrogen/xkrogen-SPARK-34133-avro-serde-casesensitivity-errormessages. Authored-by: Erik Krogen <xkrogen@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
144ee9eb30
commit
9371ea8c7b
|
@ -330,8 +330,8 @@ private[sql] class AvroDeserializer(
|
|||
var i = 0
|
||||
while (i < length) {
|
||||
val sqlField = sqlType.fields(i)
|
||||
val avroField = avroType.getField(sqlField.name)
|
||||
if (avroField != null) {
|
||||
AvroUtils.getAvroFieldByName(avroType, sqlField.name) match {
|
||||
case Some(avroField) =>
|
||||
validFieldIndexes += avroField.pos()
|
||||
|
||||
val baseWriter = newWriter(avroField.schema(), sqlField.dataType, path :+ sqlField.name)
|
||||
|
@ -344,13 +344,15 @@ private[sql] class AvroDeserializer(
|
|||
}
|
||||
}
|
||||
fieldWriters += fieldWriter
|
||||
} else if (!sqlField.nullable) {
|
||||
case None if !sqlField.nullable =>
|
||||
val fieldStr = s"${path.mkString(".")}.${sqlField.name}"
|
||||
throw new IncompatibleSchemaException(
|
||||
s"""
|
||||
|Cannot find non-nullable field ${path.mkString(".")}.${sqlField.name} in Avro schema.
|
||||
|Cannot find non-nullable field $fieldStr in Avro schema.
|
||||
|Source Avro schema: $rootAvroType.
|
||||
|Target Catalyst type: $rootCatalystType.
|
||||
""".stripMargin)
|
||||
case _ => // nothing to do
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
|
|
|
@ -230,10 +230,10 @@ private[sql] class AvroSerializer(
|
|||
|
||||
val (avroIndices: Array[Int], fieldConverters: Array[Converter]) =
|
||||
catalystStruct.map { catalystField =>
|
||||
val avroField = avroStruct.getField(catalystField.name)
|
||||
if (avroField == null) {
|
||||
throw new IncompatibleSchemaException(
|
||||
s"Cannot convert Catalyst type $catalystStruct to Avro type $avroStruct.")
|
||||
val avroField = AvroUtils.getAvroFieldByName(avroStruct, catalystField.name) match {
|
||||
case Some(f) => f
|
||||
case None => throw new IncompatibleSchemaException(
|
||||
s"Cannot find ${catalystField.name} in Avro schema")
|
||||
}
|
||||
val converter = newConverter(catalystField.dataType, resolveNullableType(
|
||||
avroField.schema(), catalystField.nullable))
|
||||
|
|
|
@ -18,6 +18,8 @@ package org.apache.spark.sql.avro
|
|||
|
||||
import java.io.{FileNotFoundException, IOException}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.avro.Schema
|
||||
import org.apache.avro.file.{DataFileReader, FileReader}
|
||||
import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC}
|
||||
|
@ -201,4 +203,33 @@ private[sql] object AvroUtils extends Logging {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a single field from `avroSchema` which has the desired field name,
|
||||
* performing the matching with proper case sensitivity according to [[SQLConf.resolver]].
|
||||
*
|
||||
* @param avroSchema The schema in which to search for the field. Must be of type RECORD.
|
||||
* @param name The name of the field to search for.
|
||||
* @return `Some(match)` if a matching Avro field is found, otherwise `None`.
|
||||
* @throws IncompatibleSchemaException if `avroSchema` is not a RECORD or contains multiple
|
||||
* fields matching `name` (i.e., case-insensitive matching
|
||||
* is used and `avroSchema` has two or more fields that have
|
||||
* the same name with difference case).
|
||||
*/
|
||||
private[avro] def getAvroFieldByName(
|
||||
avroSchema: Schema,
|
||||
name: String): Option[Schema.Field] = {
|
||||
if (avroSchema.getType != Schema.Type.RECORD) {
|
||||
throw new IncompatibleSchemaException(
|
||||
s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}")
|
||||
}
|
||||
avroSchema.getFields.asScala.filter(f => SQLConf.get.resolver(f.name(), name)).toSeq match {
|
||||
case Seq(avroField) => Some(avroField)
|
||||
case Seq() => None
|
||||
case matches => throw new IncompatibleSchemaException(
|
||||
s"Searching for '$name' in Avro schema gave ${matches.size} matches. Candidates: " +
|
||||
matches.map(_.name()).mkString("[", ", ", "]")
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.apache.hadoop.conf.Configuration
|
|||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, SparkUpgradeException}
|
||||
import org.apache.spark.TestUtils.assertExceptionMsg
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.TestingUDT.IntervalData
|
||||
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters}
|
||||
|
@ -1261,6 +1262,94 @@ abstract class AvroSuite
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-34133: Reading user provided schema respects case sensitivity for field matching") {
|
||||
val wrongCaseSchema = new StructType()
|
||||
.add("STRING", StringType, nullable = false)
|
||||
.add("UNION_STRING_NULL", StringType, nullable = true)
|
||||
val withSchema = spark.read
|
||||
.schema(wrongCaseSchema)
|
||||
.format("avro").load(testAvro).collect()
|
||||
|
||||
val withOutSchema = spark.read.format("avro").load(testAvro)
|
||||
.select("STRING", "UNION_STRING_NULL")
|
||||
.collect()
|
||||
assert(withSchema.sameElements(withOutSchema))
|
||||
|
||||
withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {
|
||||
val out = spark.read.format("avro").schema(wrongCaseSchema).load(testAvro).collect()
|
||||
assert(out.forall(_.isNullAt(0)))
|
||||
assert(out.forall(_.isNullAt(1)))
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-34133: Writing user provided schema respects case sensitivity for field matching") {
|
||||
withTempDir { tempDir =>
|
||||
val avroSchema =
|
||||
"""
|
||||
|{
|
||||
| "type" : "record",
|
||||
| "name" : "test_schema",
|
||||
| "fields" : [
|
||||
| {"name": "foo", "type": "int"},
|
||||
| {"name": "BAR", "type": "int"}
|
||||
| ]
|
||||
|}
|
||||
""".stripMargin
|
||||
val df = Seq((1, 3), (2, 4)).toDF("FOO", "bar")
|
||||
|
||||
val savePath = s"$tempDir/save"
|
||||
df.write.option("avroSchema", avroSchema).format("avro").save(savePath)
|
||||
|
||||
val loaded = spark.read.format("avro").load(savePath)
|
||||
assert(loaded.schema === new StructType().add("foo", IntegerType).add("BAR", IntegerType))
|
||||
assert(loaded.collect().map(_.getInt(0)).toSet === Set(1, 2))
|
||||
assert(loaded.collect().map(_.getInt(1)).toSet === Set(3, 4))
|
||||
|
||||
withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {
|
||||
val e = intercept[SparkException] {
|
||||
df.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save2")
|
||||
}
|
||||
assertExceptionMsg(e, "Cannot find FOO in Avro schema")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-34133: Writing user provided schema with multiple matching Avro fields fails") {
|
||||
withTempDir { tempDir =>
|
||||
val avroSchema =
|
||||
"""
|
||||
|{
|
||||
| "type" : "record",
|
||||
| "name" : "test_schema",
|
||||
| "fields" : [
|
||||
| {"name": "foo", "type": "int"},
|
||||
| {"name": "FOO", "type": "string"}
|
||||
| ]
|
||||
|}
|
||||
""".stripMargin
|
||||
|
||||
val errorMsg = "Searching for 'foo' in Avro schema gave 2 matches. Candidates: [foo, FOO]"
|
||||
assertExceptionMsg(intercept[SparkException] {
|
||||
val fooBarDf = Seq((1, "3"), (2, "4")).toDF("foo", "bar")
|
||||
fooBarDf.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save-fail")
|
||||
}, errorMsg)
|
||||
|
||||
val savePath = s"$tempDir/save"
|
||||
withSQLConf((SQLConf.CASE_SENSITIVE.key, "true")) {
|
||||
val fooFooDf = Seq((1, "3"), (2, "4")).toDF("foo", "FOO")
|
||||
fooFooDf.write.option("avroSchema", avroSchema).format("avro").save(savePath)
|
||||
|
||||
val loadedDf = spark.read.format("avro").schema(fooFooDf.schema).load(savePath)
|
||||
assert(loadedDf.collect().toSet === fooFooDf.collect().toSet)
|
||||
}
|
||||
|
||||
assertExceptionMsg(intercept[SparkException] {
|
||||
val fooSchema = new StructType().add("foo", IntegerType)
|
||||
spark.read.format("avro").schema(fooSchema).load(savePath).collect()
|
||||
}, errorMsg)
|
||||
}
|
||||
}
|
||||
|
||||
test("read avro with user defined schema: read partial columns") {
|
||||
val partialColumns = StructType(Seq(
|
||||
StructField("string", StringType, false),
|
||||
|
|
Loading…
Reference in a new issue