From be06e4156e9a7600a66822768b39af2f30b3a49e Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 2 Aug 2021 20:45:23 +0800 Subject: [PATCH] [SPARK-35918][AVRO] Unify schema mismatch handling for read/write and enhance error messages ### What changes were proposed in this pull request? This unifies struct schema mismatch-handling logic between `AvroSerializer` and `AvroDeserializer`, pushing it into `AvroUtils` which is used by both. The newly unified exception-handling logic is updated to provide more contextual information in error messages. When a schema mismatch is found, previously we would only report the first missing field that is found, but there may be any others as well, which can make it less clear what exactly is going wrong. Now, we will report on all missing fields. ### Why are the changes needed? While working on #31490, we discussed that there is room for improvement in how schema mismatch errors are reported ([comment1](https://github.com/apache/spark/pull/31490#discussion_r659970793), [comment2](https://github.com/apache/spark/pull/31490#issuecomment-869866848)). Additionally, the logic between `AvroSerializer` and `AvroDeserializer` was quite similar for handling these issues, but didn't share common code, causing duplication and making it harder to see exactly what differences existed between the two. ### Does this PR introduce _any_ user-facing change? Some error messages when matching Catalyst struct schemas against Avro record schemas now include more information. ### How was this patch tested? New unit tests added. Closes #33308 from xkrogen/xkrogen-SPARK-35918-avroserde-unify-better-error-messages. Authored-by: Erik Krogen Signed-off-by: Gengliang Wang --- .../spark/sql/avro/AvroDeserializer.scala | 50 +++++-------- .../spark/sql/avro/AvroSerializer.scala | 33 +++----- .../org/apache/spark/sql/avro/AvroUtils.scala | 75 +++++++++++++++---- .../sql/avro/AvroSchemaHelperSuite.scala | 59 ++++++++++++++- .../spark/sql/avro/AvroSerdeSuite.scala | 44 ++++++----- .../org/apache/spark/sql/avro/AvroSuite.scala | 3 +- 6 files changed, 169 insertions(+), 95 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index f1ae240f97..7955dbc00c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -21,7 +21,6 @@ import java.math.BigDecimal import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} import org.apache.avro.Conversions.DecimalConversion @@ -30,7 +29,7 @@ import org.apache.avro.Schema.Type._ import org.apache.avro.generic._ import org.apache.avro.util.Utf8 -import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} @@ -352,39 +351,26 @@ private[sql] class AvroDeserializer( avroPath: Seq[String], catalystPath: Seq[String], applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = { - val validFieldIndexes = ArrayBuffer.empty[Int] - val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] - val avroSchemaHelper = - new AvroUtils.AvroSchemaHelper(avroType, avroPath, positionalFieldMatch) - val length = catalystType.length - var i = 0 - while (i < length) { - val catalystField = catalystType.fields(i) - avroSchemaHelper.getAvroField(catalystField.name, i) match { - case Some(avroField) => - validFieldIndexes += avroField.pos() + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroType, catalystType, avroPath, catalystPath, positionalFieldMatch) - val baseWriter = newWriter(avroField.schema(), catalystField.dataType, - avroPath :+ avroField.name, catalystPath :+ catalystField.name) - val ordinal = i - val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { - if (value == null) { - fieldUpdater.setNullAt(ordinal) - } else { - baseWriter(fieldUpdater, ordinal, value) - } + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + // no need to validateNoExtraAvroFields since extra Avro fields are ignored + + val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, ordinal, avroField) => + val baseWriter = newWriter(avroField.schema(), catalystField.dataType, + avroPath :+ avroField.name, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) } - fieldWriters += fieldWriter - case None if !catalystField.nullable => - val fieldDescription = - toFieldDescription(catalystPath :+ catalystField.name, i, positionalFieldMatch) - throw new IncompatibleSchemaException( - s"Cannot find non-nullable $fieldDescription in Avro schema.") - case _ => // nothing to do - } - i += 1 - } + } + (avroField.pos(), fieldWriter) + }.toArray.unzip (fieldUpdater, record) => { var i = 0 diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index fc2ad7fb3e..1cb22bd41d 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 import org.apache.spark.internal.Logging -import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -252,34 +252,19 @@ private[sql] class AvroSerializer( catalystPath: Seq[String], avroPath: Seq[String]): InternalRow => Record = { - val avroPathStr = toFieldStr(avroPath) - if (avroStruct.getType != RECORD) { - throw new IncompatibleSchemaException(s"$avroPathStr was not a RECORD") - } - val avroFields = avroStruct.getFields.asScala - if (avroFields.size != catalystStruct.length) { - throw new IncompatibleSchemaException( - s"Avro $avroPathStr schema length (${avroFields.size}) doesn't match " + - s"SQL ${toFieldStr(catalystPath)} schema length (${catalystStruct.length})") - } - val avroSchemaHelper = - new AvroUtils.AvroSchemaHelper(avroStruct, avroPath, positionalFieldMatch) + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch) - val (avroIndices: Array[Int], fieldConverters: Array[Converter]) = - catalystStruct.zipWithIndex.map { case (catalystField, catalystPos) => - val avroField = avroSchemaHelper.getAvroField(catalystField.name, catalystPos) match { - case Some(f) => f - case None => - val fieldDescription = toFieldDescription( - catalystPath :+ catalystField.name, catalystPos, positionalFieldMatch) - throw new IncompatibleSchemaException( - s"Cannot find $fieldDescription in Avro schema at $avroPathStr") - } + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + avroSchemaHelper.validateNoExtraAvroFields() + + val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, _, avroField) => val converter = newConverter(catalystField.dataType, resolveNullableType(avroField.schema(), catalystField.nullable), catalystPath :+ catalystField.name, avroPath :+ avroField.name) (avroField.pos(), converter) - }.toArray.unzip + }.toArray.unzip val numFields = catalystStruct.length row: InternalRow => diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index f09af7451b..149d0b6e73 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -205,18 +205,32 @@ private[sql] object AvroUtils extends Logging { } } + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field. */ + case class AvroMatchedField( + catalystField: StructField, + catalystPosition: Int, + avroField: Schema.Field) + /** - * Wraps an Avro Schema object so that field lookups are faster. + * Helper class to perform field lookup/matching on Avro schemas. + * + * This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in + * the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for + * case sensitivity. The match results can be accessed using the getter methods. * * @param avroSchema The schema in which to search for fields. Must be of type RECORD. + * @param catalystSchema The Catalyst schema to use for matching. * @param avroPath The seq of parent field names leading to `avroSchema`. + * @param catalystPath The seq of parent field names leading to `catalystSchema`. * @param positionalFieldMatch If true, perform field matching in a positional fashion * (structural comparison between schemas, ignoring names); * otherwise, perform field matching using field names. */ class AvroSchemaHelper( avroSchema: Schema, + catalystSchema: StructType, avroPath: Seq[String], + catalystPath: Seq[String], positionalFieldMatch: Boolean) { if (avroSchema.getType != Schema.Type.RECORD) { throw new IncompatibleSchemaException( @@ -228,6 +242,50 @@ private[sql] object AvroUtils extends Logging { .groupBy(_.name.toLowerCase(Locale.ROOT)) .mapValues(_.toSeq) // toSeq needed for scala 2.13 + /** The fields which have matching equivalents in both Avro and Catalyst schemas. */ + val matchedFields: Seq[AvroMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getAvroField(sqlField.name, sqlPos).map(AvroMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Avro field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false, + * consider nullable Catalyst fields to be eligible to be an extra field; otherwise, + * ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) => + if (getAvroField(sqlField.name, sqlPos).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + if (positionalFieldMatch) { + throw new IncompatibleSchemaException("Cannot find field at position " + + s"$sqlPos of ${toFieldStr(avroPath)} from Avro schema (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Avro schema") + } + } + } + + /** + * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. + */ + def validateNoExtraAvroFields(): Unit = { + (avroFieldArray.toSet -- matchedFields.map(_.avroField)).foreach { extraField => + if (positionalFieldMatch) { + throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " + + s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " + + s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " + + "match in the SQL schema") + } + } + } + /** * Extract a single field from the contained avro schema which has the desired field name, * performing the matching with proper case sensitivity according to SQLConf.resolver. @@ -261,21 +319,6 @@ private[sql] object AvroUtils extends Logging { } } - /** - * Take a field's hierarchical names (see [[toFieldStr]]) and position, and convert it to a - * human-readable description of the field. Depending on the value of `positionalFieldMatch`, - * either the position or name will be emphasized (for true and false, respectively); both will - * be included in either case. - */ - private[avro] def toFieldDescription( - names: Seq[String], - position: Int, - positionalFieldMatch: Boolean): String = if (positionalFieldMatch) { - s"field at position $position (${toFieldStr(names)})" - } else { - s"${toFieldStr(names)} (at position $position)" - } - /** * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable * string representing the field, like "field 'foo.bar'". If `names` is empty, the string diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala index 7af2c47dc8..604b4e80d8 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.avro import org.apache.avro.SchemaBuilder +import org.apache.spark.sql.avro.AvroUtils.AvroMatchedField import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -28,7 +29,7 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { val avroSchema = SchemaBuilder.builder().intType() val msg = intercept[IncompatibleSchemaException] { - new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false) }.getMessage assert(msg.contains("Attempting to treat int as a RECORD")) } @@ -42,7 +43,8 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { ) val avroSchema = SchemaConverters.toAvroType(catalystSchema) - val helper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + val helper = + new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { assert(helper.getFieldByName("A").get.name() == "A") assert(helper.getFieldByName("a").get.name() == "a") @@ -69,8 +71,10 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { val catalystSchema = new StructType().add("foo", IntegerType).add("bar", StringType) val avroSchema = SchemaConverters.toAvroType(catalystSchema) - val posHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), true) - val nameHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + val posHelper = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), true) + val nameHelper = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), false) for (name <- Seq("foo", "bar"); fieldPos <- Seq(0, 1)) { assert(posHelper.getAvroField(name, fieldPos) === Some(avroSchema.getFields.get(fieldPos))) @@ -82,4 +86,51 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { assert(posHelper.getAvroField("nonexist", 1).isDefined) assert(nameHelper.getAvroField("nonexist", 1).isEmpty) } + + test("properly match fields between Avro and Catalyst schemas") { + val catalystSchema = StructType( + Seq("catalyst1", "catalyst2", "shared1", "shared2").map(StructField(_, IntegerType)) + ) + val avroSchema = SchemaBuilder.record("toplevel").fields() + .requiredInt("shared1") + .requiredInt("shared2") + .requiredInt("avro1") + .requiredInt("avro2") + .endRecord() + + val helper = new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), false) + assert(helper.matchedFields === Seq( + AvroMatchedField(catalystSchema("shared1"), 2, avroSchema.getField("shared1")), + AvroMatchedField(catalystSchema("shared2"), 3, avroSchema.getField("shared2")) + )) + assertThrows[IncompatibleSchemaException] { + helper.validateNoExtraAvroFields() + } + helper.validateNoExtraCatalystFields(ignoreNullable = true) + assertThrows[IncompatibleSchemaException] { + helper.validateNoExtraCatalystFields(ignoreNullable = false) + } + } + + test("respect nullability settings for validateNoExtraSqlFields") { + val avroSchema = SchemaBuilder.record("record").fields().requiredInt("bar").endRecord() + + val catalystNonnull = new StructType().add("foo", IntegerType, nullable = false) + val helperNonnull = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystNonnull, Seq(""), Seq(""), false) + assertThrows[IncompatibleSchemaException] { + helperNonnull.validateNoExtraCatalystFields(ignoreNullable = true) + } + assertThrows[IncompatibleSchemaException] { + helperNonnull.validateNoExtraCatalystFields(ignoreNullable = false) + } + + val catalystNullable = new StructType().add("foo", IntegerType) + val helperNullable = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystNullable, Seq(""), Seq(""), false) + helperNullable.validateNoExtraCatalystFields(ignoreNullable = true) + assertThrows[IncompatibleSchemaException] { + helperNullable.validateNoExtraCatalystFields(ignoreNullable = false) + } + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index 93796912b8..8be3cf0c2e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -86,18 +86,18 @@ class AvroSerdeSuite extends SparkFunSuite { // deserialize should have no issues when 'bar' is nullable but fail when it is nonnull Deserializer.create(CATALYST_STRUCT, avro, BY_NAME) assertFailedConversionMessage(avro, Deserializer, BY_NAME, - "Cannot find non-nullable field 'foo.bar' (at position 0) in Avro schema.", + "Cannot find field 'foo.bar' in Avro schema", nonnullCatalyst) assertFailedConversionMessage(avro, Deserializer, BY_POSITION, - "Cannot find non-nullable field at position 1 (field 'foo.baz') in Avro schema.", + "Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)", extraNonnullCatalyst) // serialize fails whether or not 'bar' is nullable - val expectMsg = "Cannot find field 'foo.bar' (at position 0) in Avro schema at field 'foo'" - assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg) - assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg, nonnullCatalyst) + val byNameMsg = "Cannot find field 'foo.bar' in Avro schema" + assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) assertFailedConversionMessage(avro, Serializer, BY_POSITION, - "Avro field 'foo' schema length (1) doesn't match SQL field 'foo' schema length (2)", + "Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)", extraNonnullCatalyst) } @@ -122,18 +122,28 @@ class AvroSerdeSuite extends SparkFunSuite { test("Fail to convert for serialization with field count mismatch") { // Note that this is allowed for deserialization, but not serialization - withFieldMatchType { fieldMatch => - val tooManyFields = - createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar")) - assertFailedConversionMessage(tooManyFields, Serializer, fieldMatch, - "Avro top-level record schema length (2) " + - "doesn't match SQL top-level record schema length (1)") + val tooManyFields = + createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar")) + assertFailedConversionMessage(tooManyFields, Serializer, BY_NAME, + "Found field 'bar' in Avro schema but there is no match in the SQL schema") + assertFailedConversionMessage(tooManyFields, Serializer, BY_POSITION, + "Found field 'bar' at position 1 of top-level record from Avro schema but there is no " + + "match in the SQL schema at top-level record (using positional matching)") - val tooFewFields = createAvroSchemaWithTopLevelFields(f => f) - assertFailedConversionMessage(tooFewFields, Serializer, fieldMatch, - "Avro top-level record schema length (0) " + - "doesn't match SQL top-level record schema length (1)") - } + val tooManyFieldsNested = + createNestedAvroSchemaWithFields("foo", _.optionalInt("bar").optionalInt("baz")) + assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_NAME, + "Found field 'foo.baz' in Avro schema but there is no match in the SQL schema") + assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_POSITION, + s"Found field 'baz' at position 1 of field 'foo' from Avro schema but there is no match " + + s"in the SQL schema at field 'foo' (using positional matching)") + + val tooFewFields = createAvroSchemaWithTopLevelFields(f => f) + assertFailedConversionMessage(tooFewFields, Serializer, BY_NAME, + "Cannot find field 'foo' in Avro schema") + assertFailedConversionMessage(tooFewFields, Serializer, BY_POSITION, + "Cannot find field at position 0 of top-level record from Avro schema " + + "(using positional matching)") } /** diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index f93c61a424..43ba20f6bd 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1402,8 +1402,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save2") } - assertExceptionMsg[IncompatibleSchemaException](e, - "Cannot find field 'FOO' (at position 0) in Avro schema at top-level record") + assertExceptionMsg[IncompatibleSchemaException](e, "Cannot find field 'FOO' in Avro schema") } } }