[SPARK-24883][SQL] Avro: remove implicit class AvroDataFrameWriter/AvroDataFrameReader

## What changes were proposed in this pull request?

As per Reynold's comment: https://github.com/apache/spark/pull/21742#discussion_r203496489

It makes sense to remove the implicit classes AvroDataFrameWriter and AvroDataFrameReader, since the Avro package is an external module.
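
For anyone relying on the removed implicits, the migration is mechanical: replace the `avro(...)` shorthand with the generic data source API. A minimal sketch of the before/after usage (paths are illustrative placeholders):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("avro-migration").getOrCreate()

// Before this patch (required `import org.apache.spark.sql.avro._`):
//   val df = spark.read.avro("/tmp/episodes.avro")
//   df.write.avro("/tmp/output")

// After this patch: use the built-in data source API with the "avro" short name.
val df = spark.read.format("avro").load("/tmp/episodes.avro")
df.write.format("avro").save("/tmp/output")

// `load` accepts varargs, so reading from multiple paths still works.
val multi = spark.read.format("avro").load("/tmp/a.avro", "/tmp/b.avro")
```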

## How was this patch tested?

Unit test

Author: Gengliang Wang <gengliang.wang@databricks.com>

Closes #21841 from gengliangwang/removeImplicit.
Authored by Gengliang Wang on 2018-07-23 15:27:33 +08:00, committed by hyukjinkwon
parent 8817c68f50
commit f59de52a2a
2 changed files with 96 additions and 110 deletions


@@ -17,30 +17,9 @@
 package org.apache.spark.sql

 import org.apache.avro.Schema

 import org.apache.spark.annotation.Experimental

 package object avro {
-  /**
-   * Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using
-   * the DataFileWriter
-   */
-  implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) {
-    def avro: String => Unit = writer.format("avro").save
-  }
-
-  /**
-   * Adds a method, `avro`, to DataFrameReader that allows you to read avro files using
-   * the DataFileReader
-   */
-  implicit class AvroDataFrameReader(reader: DataFrameReader) {
-    def avro: String => DataFrame = reader.format("avro").load
-
-    @scala.annotation.varargs
-    def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*)
-  }
-
   /**
    * Converts a binary column of avro format into its corresponding catalyst value. The specified
    * schema must match the read data, otherwise the behavior is undefined: it may fail or return


@@ -46,24 +46,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }

   def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
-    val originalEntries = spark.read.avro(testAvro).collect()
-    val newEntries = spark.read.avro(newFile)
+    val originalEntries = spark.read.format("avro").load(testAvro).collect()
+    val newEntries = spark.read.format("avro").load(newFile)
     checkAnswer(newEntries, originalEntries)
   }

   test("reading from multiple paths") {
-    val df = spark.read.avro(episodesAvro, episodesAvro)
+    val df = spark.read.format("avro").load(episodesAvro, episodesAvro)
     assert(df.count == 16)
   }

   test("reading and writing partitioned data") {
-    val df = spark.read.avro(episodesAvro)
+    val df = spark.read.format("avro").load(episodesAvro)
     val fields = List("title", "air_date", "doctor")
     for (field <- fields) {
       withTempPath { dir =>
         val outputDir = s"$dir/${UUID.randomUUID}"
-        df.write.partitionBy(field).avro(outputDir)
-        val input = spark.read.avro(outputDir)
+        df.write.partitionBy(field).format("avro").save(outputDir)
+        val input = spark.read.format("avro").load(outputDir)
         // makes sure that no fields got dropped.
         // We convert Rows to Seqs in order to work around SPARK-10325
         assert(input.select(field).collect().map(_.toSeq).toSet ===
@@ -73,14 +73,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }

   test("request no fields") {
-    val df = spark.read.avro(episodesAvro)
+    val df = spark.read.format("avro").load(episodesAvro)
     df.createOrReplaceTempView("avro_table")
     assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
   }

   test("convert formats") {
     withTempPath { dir =>
-      val df = spark.read.avro(episodesAvro)
+      val df = spark.read.format("avro").load(episodesAvro)
       df.write.parquet(dir.getCanonicalPath)
       assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
     }
@@ -88,8 +88,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

   test("rearrange internal schema") {
     withTempPath { dir =>
-      val df = spark.read.avro(episodesAvro)
-      df.select("doctor", "title").write.avro(dir.getCanonicalPath)
+      val df = spark.read.format("avro").load(episodesAvro)
+      df.select("doctor", "title").write.format("avro").save(dir.getCanonicalPath)
     }
   }
@@ -109,7 +109,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.close()

       intercept[IncompatibleSchemaException] {
-        spark.read.avro(s"$dir.avro")
+        spark.read.format("avro").load(s"$dir.avro")
       }
     }
   }
@@ -136,7 +136,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.append(rec2)
       dataFileWriter.flush()
       dataFileWriter.close()
-      val df = spark.read.avro(s"$dir.avro")
+      val df = spark.read.format("avro").load(s"$dir.avro")
       assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true)))
       assert(df.collect().toSet == Set(Row(1L), Row(2L)))
     }
@@ -164,7 +164,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.append(rec2)
       dataFileWriter.flush()
       dataFileWriter.close()
-      val df = spark.read.avro(s"$dir.avro")
+      val df = spark.read.format("avro").load(s"$dir.avro")
       assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true)))
       assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble)))
     }
@@ -196,7 +196,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.append(rec2)
       dataFileWriter.flush()
       dataFileWriter.close()
-      val df = spark.read.avro(s"$dir.avro")
+      val df = spark.read.format("avro").load(s"$dir.avro")
       assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true)))
       assert(df.collect().toSet == Set(Row(1.toDouble), Row(null)))
     }
@@ -220,7 +220,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.flush()
       dataFileWriter.close()

-      val df = spark.read.avro(s"$dir.avro")
+      val df = spark.read.format("avro").load(s"$dir.avro")
       assert(df.first() == Row(8))
     }
   }
@@ -255,7 +255,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       dataFileWriter.flush()
       dataFileWriter.close()

-      val df = spark.sqlContext.read.avro(s"$dir.avro")
+      val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
       assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
       assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
       assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
@@ -277,8 +277,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         Row(null, null, null, null, null),
         Row(null, null, null, null, null)))
       val df = spark.createDataFrame(rdd, schema)
-      df.write.avro(dir.toString)
-      assert(spark.read.avro(dir.toString).count == rdd.count)
+      df.write.format("avro").save(dir.toString)
+      assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
     }
   }
@@ -296,8 +296,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         Row(3f, 3.toShort, 3.toByte, true)
       ))
       val df = spark.createDataFrame(rdd, schema)
-      df.write.avro(dir.toString)
-      assert(spark.read.avro(dir.toString).count == rdd.count)
+      df.write.format("avro").save(dir.toString)
+      assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
     }
   }
@@ -314,9 +314,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         Row(3f, new Date(1460066400500L))
       ))
       val df = spark.createDataFrame(rdd, schema)
-      df.write.avro(dir.toString)
-      assert(spark.read.avro(dir.toString).count == rdd.count)
-      assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet ==
+      df.write.format("avro").save(dir.toString)
+      assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
+      assert(
+        spark.read.format("avro").load(dir.toString).select("date").collect().map(_(0)).toSet ==
         Array(null, 1451865600000L, 1459987200000L).toSet)
     }
   }
@@ -350,8 +351,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")),
         Array[Row](Row("Bobby G. can't swim")))))
       val df = spark.createDataFrame(rdd, testSchema)
-      df.write.avro(dir.toString)
-      assert(spark.read.avro(dir.toString).count == rdd.count)
+      df.write.format("avro").save(dir.toString)
+      assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
     }
   }
@@ -363,14 +364,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val deflateDir = s"$dir/deflate"
       val snappyDir = s"$dir/snappy"

-      val df = spark.read.avro(testAvro)
+      val df = spark.read.format("avro").load(testAvro)
       spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
-      df.write.avro(uncompressDir)
+      df.write.format("avro").save(uncompressDir)
       spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate")
       spark.conf.set(AVRO_DEFLATE_LEVEL, "9")
-      df.write.avro(deflateDir)
+      df.write.format("avro").save(deflateDir)
       spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy")
-      df.write.avro(snappyDir)
+      df.write.format("avro").save(snappyDir)

       val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir))
       val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir))
@@ -382,49 +383,50 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   }

   test("dsl test") {
-    val results = spark.read.avro(episodesAvro).select("title").collect()
+    val results = spark.read.format("avro").load(episodesAvro).select("title").collect()
     assert(results.length === 8)
   }

   test("support of various data types") {
     // This test uses data from test.avro. You can see the data and the schema of this file in
     // test.json and test.avsc
-    val all = spark.read.avro(testAvro).collect()
+    val all = spark.read.format("avro").load(testAvro).collect()
     assert(all.length == 3)

-    val str = spark.read.avro(testAvro).select("string").collect()
+    val str = spark.read.format("avro").load(testAvro).select("string").collect()
     assert(str.map(_(0)).toSet.contains("Terran is IMBA!"))

-    val simple_map = spark.read.avro(testAvro).select("simple_map").collect()
+    val simple_map = spark.read.format("avro").load(testAvro).select("simple_map").collect()
     assert(simple_map(0)(0).getClass.toString.contains("Map"))
     assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0))

-    val union0 = spark.read.avro(testAvro).select("union_string_null").collect()
+    val union0 = spark.read.format("avro").load(testAvro).select("union_string_null").collect()
     assert(union0.map(_(0)).toSet == Set("abc", "123", null))

-    val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect()
+    val union1 = spark.read.format("avro").load(testAvro).select("union_int_long_null").collect()
     assert(union1.map(_(0)).toSet == Set(66, 1, null))

-    val union2 = spark.read.avro(testAvro).select("union_float_double").collect()
+    val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect()
     assert(
       union2
         .map(x => new java.lang.Double(x(0).toString))
         .exists(p => Math.abs(p - Math.PI) < 0.001))

-    val fixed = spark.read.avro(testAvro).select("fixed3").collect()
+    val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect()
     assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3))

-    val enum = spark.read.avro(testAvro).select("enum").collect()
+    val enum = spark.read.format("avro").load(testAvro).select("enum").collect()
     assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS"))

-    val record = spark.read.avro(testAvro).select("record").collect()
+    val record = spark.read.format("avro").load(testAvro).select("record").collect()
     assert(record(0)(0).getClass.toString.contains("Row"))
     assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123"))

-    val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect()
+    val array_of_boolean =
+      spark.read.format("avro").load(testAvro).select("array_of_boolean").collect()
     assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0))

-    val bytes = spark.read.avro(testAvro).select("bytes").collect()
+    val bytes = spark.read.format("avro").load(testAvro).select("bytes").collect()
     assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0))
   }
@@ -444,7 +446,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     // get the same values back.
     withTempPath { dir =>
       val avroDir = s"$dir/avro"
-      spark.read.avro(testAvro).write.avro(avroDir)
+      spark.read.format("avro").load(testAvro).write.format("avro").save(avroDir)
       checkReloadMatchesSaved(testAvro, avroDir)
     }
   }
@@ -458,7 +460,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)

       val avroDir = tempDir + "/namedAvro"
-      spark.read.avro(testAvro).write.options(parameters).avro(avroDir)
+      spark.read.format("avro").load(testAvro)
+        .write.options(parameters).format("avro").save(avroDir)
       checkReloadMatchesSaved(testAvro, avroDir)

       // Look at raw file and make sure has namespace info
@@ -489,22 +492,22 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val cityDataFrame = spark.createDataFrame(cityRDD, testSchema)

       val avroDir = tempDir + "/avro"
-      cityDataFrame.write.avro(avroDir)
-      assert(spark.read.avro(avroDir).collect().length == 3)
+      cityDataFrame.write.format("avro").save(avroDir)
+      assert(spark.read.format("avro").load(avroDir).collect().length == 3)

       // TimesStamps are converted to longs
-      val times = spark.read.avro(avroDir).select("Time").collect()
+      val times = spark.read.format("avro").load(avroDir).select("Time").collect()
       assert(times.map(_(0)).toSet == Set(666, 777, 42))

       // DecimalType should be converted to string
-      val decimals = spark.read.avro(avroDir).select("Decimal").collect()
+      val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
       assert(decimals.map(_(0)).contains("3.14"))

       // There should be a null entry
-      val length = spark.read.avro(avroDir).select("Length").collect()
+      val length = spark.read.format("avro").load(avroDir).select("Length").collect()
       assert(length.map(_(0)).contains(null))

-      val binary = spark.read.avro(avroDir).select("Binary").collect()
+      val binary = spark.read.format("avro").load(avroDir).select("Binary").collect()
       for (i <- arrayOfByte.indices) {
         assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i))
       }
@@ -523,10 +526,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val writeDs = Seq((currentDate, currentTime)).toDS

       val avroDir = tempDir + "/avro"
-      writeDs.write.avro(avroDir)
-      assert(spark.read.avro(avroDir).collect().length == 1)
+      writeDs.write.format("avro").save(avroDir)
+      assert(spark.read.format("avro").load(avroDir).collect().length == 1)

-      val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)]
+      val readDs = spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)]

       assert(readDs.collect().sameElements(writeDs.collect()))
     }
@@ -534,10 +537,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

   test("support of globbed paths") {
     val resourceDir = testFile(".")
-    val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect()
+    val e1 = spark.read.format("avro").load(resourceDir + "../*/episodes.avro").collect()
     assert(e1.length == 8)

-    val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect()
+    val e2 = spark.read.format("avro").load(resourceDir + "../../*/*/episodes.avro").collect()
     assert(e2.length == 8)
   }
@@ -555,8 +558,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val writeDs = Seq((nullDate, nullTime)).toDS

       val avroDir = tempDir + "/avro"
-      writeDs.write.avro(avroDir)
-      val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect
+      writeDs.write.format("avro").save(avroDir)
+      val readValues =
+        spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)].collect

       assert(readValues.size == 1)
       assert(readValues.head == ((nullDate, nullTime)))
@@ -579,9 +583,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     val result = spark
       .read
       .option("avroSchema", avroSchema)
-      .avro(testAvro)
+      .format("avro")
+      .load(testAvro)
       .collect()
-    val expected = spark.read.avro(testAvro).select("string").collect()
+    val expected = spark.read.format("avro").load(testAvro).select("string").collect()
     assert(result.sameElements(expected))
   }
@@ -601,7 +606,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     val result = spark
       .read
       .option("avroSchema", avroSchema)
-      .avro(testAvro).select("missingField").first
+      .format("avro").load(testAvro).select("missingField").first
     assert(result === Row("foo"))
   }
@@ -609,17 +614,17 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

     // Directory given has no avro files
     intercept[AnalysisException] {
-      withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
+      withTempPath(dir => spark.read.format("avro").load(dir.getCanonicalPath))
     }

     intercept[AnalysisException] {
-      spark.read.avro("very/invalid/path/123.avro")
+      spark.read.format("avro").load("very/invalid/path/123.avro")
     }

     // In case of globbed path that can't be matched to anything, another exception is thrown (and
     // exception message is helpful)
     intercept[AnalysisException] {
-      spark.read.avro("*/*/*/*/*/*/*/something.avro")
+      spark.read.format("avro").load("*/*/*/*/*/*/*/something.avro")
     }

     intercept[FileNotFoundException] {
@@ -628,7 +633,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
       try {
         hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
-        spark.read.avro(dir.toString)
+        spark.read.format("avro").load(dir.toString)
       } finally {
         hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
       }
@@ -642,7 +647,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
        spark
          .read
          .option("ignoreExtension", false)
-          .avro(dir.toString)
+          .format("avro")
+          .load(dir.toString)
       }
     }
   }
@@ -681,13 +687,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("test save and load") {
     // Test if load works as expected
     withTempPath { tempDir =>
-      val df = spark.read.avro(episodesAvro)
+      val df = spark.read.format("avro").load(episodesAvro)
       assert(df.count == 8)

       val tempSaveDir = s"$tempDir/save/"

-      df.write.avro(tempSaveDir)
-      val newDf = spark.read.avro(tempSaveDir)
+      df.write.format("avro").save(tempSaveDir)
+      val newDf = spark.read.format("avro").load(tempSaveDir)
       assert(newDf.count == 8)
     }
   }
@@ -695,20 +701,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("test load with non-Avro file") {
     // Test if load works as expected
     withTempPath { tempDir =>
-      val df = spark.read.avro(episodesAvro)
+      val df = spark.read.format("avro").load(episodesAvro)
       assert(df.count == 8)

       val tempSaveDir = s"$tempDir/save/"
-      df.write.avro(tempSaveDir)
+      df.write.format("avro").save(tempSaveDir)

       Files.createFile(new File(tempSaveDir, "non-avro").toPath)

       val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
       val count = try {
         hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
-        val newDf = spark
-          .read
-          .avro(tempSaveDir)
+        val newDf = spark.read.format("avro").load(tempSaveDir)
         newDf.count()
       } finally {
         hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
@@ -730,10 +734,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false),
       StructField("array_of_boolean", ArrayType(BooleanType), false),
       StructField("bytes", BinaryType, true)))
-    val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect()
+    val withSchema = spark.read.schema(partialColumns).format("avro").load(testAvro).collect()
     val withOutSchema = spark
       .read
-      .avro(testAvro)
+      .format("avro")
+      .load(testAvro)
       .select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null",
         "fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes")
       .collect()
@@ -751,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         StructField("non_exist_field", StringType, false),
         StructField("non_exist_field2", StringType, false))),
       false)))
-    val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect()
+    val withEmptyColumn = spark.read.schema(schema).format("avro").load(testAvro).collect()

     assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String))))
   }
@@ -762,8 +767,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       import sparkSession.implicits._

       val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
       val outputDir = s"$dir/${UUID.randomUUID}"
-      df.write.avro(outputDir)
-      val input = spark.read.avro(outputDir)
+      df.write.format("avro").save(outputDir)
+      val input = spark.read.format("avro").load(outputDir)
       assert(input.collect.toSet.size === 1024 * 3 + 1)
       assert(input.rdd.partitions.size > 2)
     }
@@ -780,9 +785,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
      // Save avro file on output folder path
      val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
      val outputFolder = s"$tempDir/duplicate_names/"
-      writeDf.write.avro(outputFolder)
+      writeDf.write.format("avro").save(outputFolder)
      // Read avro file saved on the last step
-      val readDf = spark.read.avro(outputFolder)
+      val readDf = spark.read.format("avro").load(outputFolder)
      // Check if the written DataFrame is equals than read DataFrame
      assert(readDf.collect().sameElements(writeDf.collect()))
    }
@@ -801,9 +806,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
        ))))
      )
      val outputFolder = s"$tempDir/duplicate_names_array/"
-      writeDf.write.avro(outputFolder)
+      writeDf.write.format("avro").save(outputFolder)
      // Read avro file saved on the last step
-      val readDf = spark.read.avro(outputFolder)
+      val readDf = spark.read.format("avro").load(outputFolder)
      // Check if the written DataFrame is equals than read DataFrame
      assert(readDf.collect().sameElements(writeDf.collect()))
    }
@@ -822,9 +827,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
        ))))
      )
      val outputFolder = s"$tempDir/duplicate_names_map/"
-      writeDf.write.avro(outputFolder)
+      writeDf.write.format("avro").save(outputFolder)
      // Read avro file saved on the last step
-      val readDf = spark.read.avro(outputFolder)
+      val readDf = spark.read.format("avro").load(outputFolder)
      // Check if the written DataFrame is equals than read DataFrame
      assert(readDf.collect().sameElements(writeDf.collect()))
    }
@@ -837,32 +842,33 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
        Paths.get(dir.getCanonicalPath, "episodes"))

      val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes"
-      val df1 = spark.read.avro(fileWithoutExtension)
+      val df1 = spark.read.format("avro").load(fileWithoutExtension)
      assert(df1.count == 8)

      val schema = new StructType()
        .add("title", StringType)
        .add("air_date", StringType)
        .add("doctor", IntegerType)
-      val df2 = spark.read.schema(schema).avro(fileWithoutExtension)
+      val df2 = spark.read.schema(schema).format("avro").load(fileWithoutExtension)
      assert(df2.count == 8)
    }
  }

  test("SPARK-24836: checking the ignoreExtension option") {
    withTempPath { tempDir =>
-      val df = spark.read.avro(episodesAvro)
+      val df = spark.read.format("avro").load(episodesAvro)
      assert(df.count == 8)

      val tempSaveDir = s"$tempDir/save/"
-      df.write.avro(tempSaveDir)
+      df.write.format("avro").save(tempSaveDir)

      Files.createFile(new File(tempSaveDir, "non-avro").toPath)

      val newDf = spark
        .read
        .option("ignoreExtension", false)
-        .avro(tempSaveDir)
+        .format("avro")
+        .load(tempSaveDir)

      assert(newDf.count == 8)
    }
@@ -880,7 +886,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
        val newDf = spark
          .read
          .option("ignoreExtension", "true")
-          .avro(s"${dir.getCanonicalPath}/episodes")
+          .format("avro")
+          .load(s"${dir.getCanonicalPath}/episodes")
        newDf.count()
      } finally {
        hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)