diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index f1ae240f97..7955dbc00c 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -21,7 +21,6 @@ import java.math.BigDecimal import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} import org.apache.avro.Conversions.DecimalConversion @@ -30,7 +29,7 @@ import org.apache.avro.Schema.Type._ import org.apache.avro.generic._ import org.apache.avro.util.Utf8 -import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} @@ -352,39 +351,26 @@ private[sql] class AvroDeserializer( avroPath: Seq[String], catalystPath: Seq[String], applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = { - val validFieldIndexes = ArrayBuffer.empty[Int] - val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit] - val avroSchemaHelper = - new AvroUtils.AvroSchemaHelper(avroType, avroPath, positionalFieldMatch) - val length = catalystType.length - var i = 0 - while (i < length) { - val catalystField = catalystType.fields(i) - avroSchemaHelper.getAvroField(catalystField.name, i) match { - case Some(avroField) => - validFieldIndexes += avroField.pos() + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroType, catalystType, avroPath, catalystPath, positionalFieldMatch) - val baseWriter = newWriter(avroField.schema(), catalystField.dataType, - avroPath :+ avroField.name, catalystPath :+ catalystField.name) - val ordinal = i - val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { - if (value == null) { - fieldUpdater.setNullAt(ordinal) - } else { - baseWriter(fieldUpdater, ordinal, value) - } + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + // no need to validateNoExtraAvroFields since extra Avro fields are ignored + + val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, ordinal, avroField) => + val baseWriter = newWriter(avroField.schema(), catalystField.dataType, + avroPath :+ avroField.name, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) } - fieldWriters += fieldWriter - case None if !catalystField.nullable => - val fieldDescription = - toFieldDescription(catalystPath :+ catalystField.name, i, positionalFieldMatch) - throw new IncompatibleSchemaException( - s"Cannot find non-nullable $fieldDescription in Avro schema.") - case _ => // nothing to do - } - i += 1 - } + } + (avroField.pos(), fieldWriter) + }.toArray.unzip (fieldUpdater, record) => { var i = 0 diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index fc2ad7fb3e..1cb22bd41d 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 import org.apache.spark.internal.Logging -import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr} +import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -252,34 +252,19 @@ private[sql] class AvroSerializer( catalystPath: Seq[String], avroPath: Seq[String]): InternalRow => Record = { - val avroPathStr = toFieldStr(avroPath) - if (avroStruct.getType != RECORD) { - throw new IncompatibleSchemaException(s"$avroPathStr was not a RECORD") - } - val avroFields = avroStruct.getFields.asScala - if (avroFields.size != catalystStruct.length) { - throw new IncompatibleSchemaException( - s"Avro $avroPathStr schema length (${avroFields.size}) doesn't match " + - s"SQL ${toFieldStr(catalystPath)} schema length (${catalystStruct.length})") - } - val avroSchemaHelper = - new AvroUtils.AvroSchemaHelper(avroStruct, avroPath, positionalFieldMatch) + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch) - val (avroIndices: Array[Int], fieldConverters: Array[Converter]) = - catalystStruct.zipWithIndex.map { case (catalystField, catalystPos) => - val avroField = avroSchemaHelper.getAvroField(catalystField.name, catalystPos) match { - case Some(f) => f - case None => - val fieldDescription = toFieldDescription( - catalystPath :+ catalystField.name, catalystPos, positionalFieldMatch) - throw new IncompatibleSchemaException( - s"Cannot find $fieldDescription in Avro schema at $avroPathStr") - } + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + avroSchemaHelper.validateNoExtraAvroFields() + + val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, _, avroField) => val converter = newConverter(catalystField.dataType, resolveNullableType(avroField.schema(), catalystField.nullable), catalystPath :+ catalystField.name, avroPath :+ avroField.name) (avroField.pos(), converter) - }.toArray.unzip + }.toArray.unzip val numFields = catalystStruct.length row: InternalRow => diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index f09af7451b..149d0b6e73 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -205,18 +205,32 @@ private[sql] object AvroUtils extends Logging { } } + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field. */ + case class AvroMatchedField( + catalystField: StructField, + catalystPosition: Int, + avroField: Schema.Field) + /** - * Wraps an Avro Schema object so that field lookups are faster. + * Helper class to perform field lookup/matching on Avro schemas. + * + * This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in + * the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for + * case sensitivity. The match results can be accessed using the getter methods. * * @param avroSchema The schema in which to search for fields. Must be of type RECORD. + * @param catalystSchema The Catalyst schema to use for matching. * @param avroPath The seq of parent field names leading to `avroSchema`. + * @param catalystPath The seq of parent field names leading to `catalystSchema`. * @param positionalFieldMatch If true, perform field matching in a positional fashion * (structural comparison between schemas, ignoring names); * otherwise, perform field matching using field names. */ class AvroSchemaHelper( avroSchema: Schema, + catalystSchema: StructType, avroPath: Seq[String], + catalystPath: Seq[String], positionalFieldMatch: Boolean) { if (avroSchema.getType != Schema.Type.RECORD) { throw new IncompatibleSchemaException( @@ -228,6 +242,50 @@ private[sql] object AvroUtils extends Logging { .groupBy(_.name.toLowerCase(Locale.ROOT)) .mapValues(_.toSeq) // toSeq needed for scala 2.13 + /** The fields which have matching equivalents in both Avro and Catalyst schemas. */ + val matchedFields: Seq[AvroMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getAvroField(sqlField.name, sqlPos).map(AvroMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Avro field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false, + * consider nullable Catalyst fields to be eligible to be an extra field; otherwise, + * ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) => + if (getAvroField(sqlField.name, sqlPos).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + if (positionalFieldMatch) { + throw new IncompatibleSchemaException("Cannot find field at position " + + s"$sqlPos of ${toFieldStr(avroPath)} from Avro schema (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Avro schema") + } + } + } + + /** + * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. + */ + def validateNoExtraAvroFields(): Unit = { + (avroFieldArray.toSet -- matchedFields.map(_.avroField)).foreach { extraField => + if (positionalFieldMatch) { + throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " + + s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " + + s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " + + "match in the SQL schema") + } + } + } + /** * Extract a single field from the contained avro schema which has the desired field name, * performing the matching with proper case sensitivity according to SQLConf.resolver. @@ -261,21 +319,6 @@ private[sql] object AvroUtils extends Logging { } } - /** - * Take a field's hierarchical names (see [[toFieldStr]]) and position, and convert it to a - * human-readable description of the field. Depending on the value of `positionalFieldMatch`, - * either the position or name will be emphasized (for true and false, respectively); both will - * be included in either case. - */ - private[avro] def toFieldDescription( - names: Seq[String], - position: Int, - positionalFieldMatch: Boolean): String = if (positionalFieldMatch) { - s"field at position $position (${toFieldStr(names)})" - } else { - s"${toFieldStr(names)} (at position $position)" - } - /** * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable * string representing the field, like "field 'foo.bar'". If `names` is empty, the string diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala index 7af2c47dc8..604b4e80d8 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSchemaHelperSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.avro import org.apache.avro.SchemaBuilder +import org.apache.spark.sql.avro.AvroUtils.AvroMatchedField import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -28,7 +29,7 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { val avroSchema = SchemaBuilder.builder().intType() val msg = intercept[IncompatibleSchemaException] { - new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false) }.getMessage assert(msg.contains("Attempting to treat int as a RECORD")) } @@ -42,7 +43,8 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { ) val avroSchema = SchemaConverters.toAvroType(catalystSchema) - val helper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + val helper = + new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { assert(helper.getFieldByName("A").get.name() == "A") assert(helper.getFieldByName("a").get.name() == "a") @@ -69,8 +71,10 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { val catalystSchema = new StructType().add("foo", IntegerType).add("bar", StringType) val avroSchema = SchemaConverters.toAvroType(catalystSchema) - val posHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), true) - val nameHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false) + val posHelper = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), true) + val nameHelper = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), false) for (name <- Seq("foo", "bar"); fieldPos <- Seq(0, 1)) { assert(posHelper.getAvroField(name, fieldPos) === Some(avroSchema.getFields.get(fieldPos))) @@ -82,4 +86,51 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession { assert(posHelper.getAvroField("nonexist", 1).isDefined) assert(nameHelper.getAvroField("nonexist", 1).isEmpty) } + + test("properly match fields between Avro and Catalyst schemas") { + val catalystSchema = StructType( + Seq("catalyst1", "catalyst2", "shared1", "shared2").map(StructField(_, IntegerType)) + ) + val avroSchema = SchemaBuilder.record("toplevel").fields() + .requiredInt("shared1") + .requiredInt("shared2") + .requiredInt("avro1") + .requiredInt("avro2") + .endRecord() + + val helper = new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), false) + assert(helper.matchedFields === Seq( + AvroMatchedField(catalystSchema("shared1"), 2, avroSchema.getField("shared1")), + AvroMatchedField(catalystSchema("shared2"), 3, avroSchema.getField("shared2")) + )) + assertThrows[IncompatibleSchemaException] { + helper.validateNoExtraAvroFields() + } + helper.validateNoExtraCatalystFields(ignoreNullable = true) + assertThrows[IncompatibleSchemaException] { + helper.validateNoExtraCatalystFields(ignoreNullable = false) + } + } + + test("respect nullability settings for validateNoExtraSqlFields") { + val avroSchema = SchemaBuilder.record("record").fields().requiredInt("bar").endRecord() + + val catalystNonnull = new StructType().add("foo", IntegerType, nullable = false) + val helperNonnull = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystNonnull, Seq(""), Seq(""), false) + assertThrows[IncompatibleSchemaException] { + helperNonnull.validateNoExtraCatalystFields(ignoreNullable = true) + } + assertThrows[IncompatibleSchemaException] { + helperNonnull.validateNoExtraCatalystFields(ignoreNullable = false) + } + + val catalystNullable = new StructType().add("foo", IntegerType) + val helperNullable = + new AvroUtils.AvroSchemaHelper(avroSchema, catalystNullable, Seq(""), Seq(""), false) + helperNullable.validateNoExtraCatalystFields(ignoreNullable = true) + assertThrows[IncompatibleSchemaException] { + helperNullable.validateNoExtraCatalystFields(ignoreNullable = false) + } + } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index 93796912b8..8be3cf0c2e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -86,18 +86,18 @@ class AvroSerdeSuite extends SparkFunSuite { // deserialize should have no issues when 'bar' is nullable but fail when it is nonnull Deserializer.create(CATALYST_STRUCT, avro, BY_NAME) assertFailedConversionMessage(avro, Deserializer, BY_NAME, - "Cannot find non-nullable field 'foo.bar' (at position 0) in Avro schema.", + "Cannot find field 'foo.bar' in Avro schema", nonnullCatalyst) assertFailedConversionMessage(avro, Deserializer, BY_POSITION, - "Cannot find non-nullable field at position 1 (field 'foo.baz') in Avro schema.", + "Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)", extraNonnullCatalyst) // serialize fails whether or not 'bar' is nullable - val expectMsg = "Cannot find field 'foo.bar' (at position 0) in Avro schema at field 'foo'" - assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg) - assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg, nonnullCatalyst) + val byNameMsg = "Cannot find field 'foo.bar' in Avro schema" + assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) assertFailedConversionMessage(avro, Serializer, BY_POSITION, - "Avro field 'foo' schema length (1) doesn't match SQL field 'foo' schema length (2)", + "Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)", extraNonnullCatalyst) } @@ -122,18 +122,28 @@ class AvroSerdeSuite extends SparkFunSuite { test("Fail to convert for serialization with field count mismatch") { // Note that this is allowed for deserialization, but not serialization - withFieldMatchType { fieldMatch => - val tooManyFields = - createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar")) - assertFailedConversionMessage(tooManyFields, Serializer, fieldMatch, - "Avro top-level record schema length (2) " + - "doesn't match SQL top-level record schema length (1)") + val tooManyFields = + createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar")) + assertFailedConversionMessage(tooManyFields, Serializer, BY_NAME, + "Found field 'bar' in Avro schema but there is no match in the SQL schema") + assertFailedConversionMessage(tooManyFields, Serializer, BY_POSITION, + "Found field 'bar' at position 1 of top-level record from Avro schema but there is no " + + "match in the SQL schema at top-level record (using positional matching)") - val tooFewFields = createAvroSchemaWithTopLevelFields(f => f) - assertFailedConversionMessage(tooFewFields, Serializer, fieldMatch, - "Avro top-level record schema length (0) " + - "doesn't match SQL top-level record schema length (1)") - } + val tooManyFieldsNested = + createNestedAvroSchemaWithFields("foo", _.optionalInt("bar").optionalInt("baz")) + assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_NAME, + "Found field 'foo.baz' in Avro schema but there is no match in the SQL schema") + assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_POSITION, + s"Found field 'baz' at position 1 of field 'foo' from Avro schema but there is no match " + + s"in the SQL schema at field 'foo' (using positional matching)") + + val tooFewFields = createAvroSchemaWithTopLevelFields(f => f) + assertFailedConversionMessage(tooFewFields, Serializer, BY_NAME, + "Cannot find field 'foo' in Avro schema") + assertFailedConversionMessage(tooFewFields, Serializer, BY_POSITION, + "Cannot find field at position 0 of top-level record from Avro schema " + + "(using positional matching)") } /** diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index f93c61a424..43ba20f6bd 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1402,8 +1402,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.option("avroSchema", avroSchema).format("avro").save(s"$tempDir/save2") } - assertExceptionMsg[IncompatibleSchemaException](e, - "Cannot find field 'FOO' (at position 0) in Avro schema at top-level record") + assertExceptionMsg[IncompatibleSchemaException](e, "Cannot find field 'FOO' in Avro schema") } } }