[SPARK-24883][SQL] Avro: remove implicit class AvroDataFrameWriter/AvroDataFrameReader
## What changes were proposed in this pull request? As per Reynold's comment: https://github.com/apache/spark/pull/21742#discussion_r203496489 It makes sense to remove the implicit classes AvroDataFrameWriter/AvroDataFrameReader, since the Avro package is an external module. ## How was this patch tested? Unit tests. Author: Gengliang Wang <gengliang.wang@databricks.com> Closes #21841 from gengliangwang/removeImplicit.
This commit is contained in:
parent
8817c68f50
commit
f59de52a2a
|
@ -17,30 +17,9 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.avro.Schema
|
||||
|
||||
import org.apache.spark.annotation.Experimental
|
||||
|
||||
package object avro {
|
||||
/**
|
||||
* Adds a method, `avro`, to DataFrameWriter that allows you to write avro files using
|
||||
* the DataFileWriter
|
||||
*/
|
||||
implicit class AvroDataFrameWriter[T](writer: DataFrameWriter[T]) {
|
||||
def avro: String => Unit = writer.format("avro").save
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a method, `avro`, to DataFrameReader that allows you to read avro files using
|
||||
* the DataFileReader
|
||||
*/
|
||||
implicit class AvroDataFrameReader(reader: DataFrameReader) {
|
||||
def avro: String => DataFrame = reader.format("avro").load
|
||||
|
||||
@scala.annotation.varargs
|
||||
def avro(sources: String*): DataFrame = reader.format("avro").load(sources: _*)
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a binary column of avro format into its corresponding catalyst value. The specified
|
||||
* schema must match the read data, otherwise the behavior is undefined: it may fail or return
|
||||
|
|
|
@ -46,24 +46,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
}
|
||||
|
||||
def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = {
|
||||
val originalEntries = spark.read.avro(testAvro).collect()
|
||||
val newEntries = spark.read.avro(newFile)
|
||||
val originalEntries = spark.read.format("avro").load(testAvro).collect()
|
||||
val newEntries = spark.read.format("avro").load(newFile)
|
||||
checkAnswer(newEntries, originalEntries)
|
||||
}
|
||||
|
||||
test("reading from multiple paths") {
|
||||
val df = spark.read.avro(episodesAvro, episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro, episodesAvro)
|
||||
assert(df.count == 16)
|
||||
}
|
||||
|
||||
test("reading and writing partitioned data") {
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
val fields = List("title", "air_date", "doctor")
|
||||
for (field <- fields) {
|
||||
withTempPath { dir =>
|
||||
val outputDir = s"$dir/${UUID.randomUUID}"
|
||||
df.write.partitionBy(field).avro(outputDir)
|
||||
val input = spark.read.avro(outputDir)
|
||||
df.write.partitionBy(field).format("avro").save(outputDir)
|
||||
val input = spark.read.format("avro").load(outputDir)
|
||||
// makes sure that no fields got dropped.
|
||||
// We convert Rows to Seqs in order to work around SPARK-10325
|
||||
assert(input.select(field).collect().map(_.toSeq).toSet ===
|
||||
|
@ -73,14 +73,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("request no fields") {
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
df.createOrReplaceTempView("avro_table")
|
||||
assert(spark.sql("select count(*) from avro_table").collect().head === Row(8))
|
||||
}
|
||||
|
||||
test("convert formats") {
|
||||
withTempPath { dir =>
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
df.write.parquet(dir.getCanonicalPath)
|
||||
assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count)
|
||||
}
|
||||
|
@ -88,8 +88,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
|
||||
test("rearrange internal schema") {
|
||||
withTempPath { dir =>
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
df.select("doctor", "title").write.avro(dir.getCanonicalPath)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
df.select("doctor", "title").write.format("avro").save(dir.getCanonicalPath)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -109,7 +109,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.close()
|
||||
|
||||
intercept[IncompatibleSchemaException] {
|
||||
spark.read.avro(s"$dir.avro")
|
||||
spark.read.format("avro").load(s"$dir.avro")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.append(rec2)
|
||||
dataFileWriter.flush()
|
||||
dataFileWriter.close()
|
||||
val df = spark.read.avro(s"$dir.avro")
|
||||
val df = spark.read.format("avro").load(s"$dir.avro")
|
||||
assert(df.schema.fields === Seq(StructField("field1", LongType, nullable = true)))
|
||||
assert(df.collect().toSet == Set(Row(1L), Row(2L)))
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.append(rec2)
|
||||
dataFileWriter.flush()
|
||||
dataFileWriter.close()
|
||||
val df = spark.read.avro(s"$dir.avro")
|
||||
val df = spark.read.format("avro").load(s"$dir.avro")
|
||||
assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true)))
|
||||
assert(df.collect().toSet == Set(Row(1.toDouble), Row(2.toDouble)))
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.append(rec2)
|
||||
dataFileWriter.flush()
|
||||
dataFileWriter.close()
|
||||
val df = spark.read.avro(s"$dir.avro")
|
||||
val df = spark.read.format("avro").load(s"$dir.avro")
|
||||
assert(df.schema.fields === Seq(StructField("field1", DoubleType, nullable = true)))
|
||||
assert(df.collect().toSet == Set(Row(1.toDouble), Row(null)))
|
||||
}
|
||||
|
@ -220,7 +220,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.flush()
|
||||
dataFileWriter.close()
|
||||
|
||||
val df = spark.read.avro(s"$dir.avro")
|
||||
val df = spark.read.format("avro").load(s"$dir.avro")
|
||||
assert(df.first() == Row(8))
|
||||
}
|
||||
}
|
||||
|
@ -255,7 +255,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
dataFileWriter.flush()
|
||||
dataFileWriter.close()
|
||||
|
||||
val df = spark.sqlContext.read.avro(s"$dir.avro")
|
||||
val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
|
||||
assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
|
||||
assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
|
||||
assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
|
||||
|
@ -277,8 +277,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
Row(null, null, null, null, null),
|
||||
Row(null, null, null, null, null)))
|
||||
val df = spark.createDataFrame(rdd, schema)
|
||||
df.write.avro(dir.toString)
|
||||
assert(spark.read.avro(dir.toString).count == rdd.count)
|
||||
df.write.format("avro").save(dir.toString)
|
||||
assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -296,8 +296,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
Row(3f, 3.toShort, 3.toByte, true)
|
||||
))
|
||||
val df = spark.createDataFrame(rdd, schema)
|
||||
df.write.avro(dir.toString)
|
||||
assert(spark.read.avro(dir.toString).count == rdd.count)
|
||||
df.write.format("avro").save(dir.toString)
|
||||
assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -314,9 +314,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
Row(3f, new Date(1460066400500L))
|
||||
))
|
||||
val df = spark.createDataFrame(rdd, schema)
|
||||
df.write.avro(dir.toString)
|
||||
assert(spark.read.avro(dir.toString).count == rdd.count)
|
||||
assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet ==
|
||||
df.write.format("avro").save(dir.toString)
|
||||
assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
|
||||
assert(
|
||||
spark.read.format("avro").load(dir.toString).select("date").collect().map(_(0)).toSet ==
|
||||
Array(null, 1451865600000L, 1459987200000L).toSet)
|
||||
}
|
||||
}
|
||||
|
@ -350,8 +351,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
Array[Array[String]](Array[String]("CSH, tearing down the walls that divide us", "-jd")),
|
||||
Array[Row](Row("Bobby G. can't swim")))))
|
||||
val df = spark.createDataFrame(rdd, testSchema)
|
||||
df.write.avro(dir.toString)
|
||||
assert(spark.read.avro(dir.toString).count == rdd.count)
|
||||
df.write.format("avro").save(dir.toString)
|
||||
assert(spark.read.format("avro").load(dir.toString).count == rdd.count)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -363,14 +364,14 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val deflateDir = s"$dir/deflate"
|
||||
val snappyDir = s"$dir/snappy"
|
||||
|
||||
val df = spark.read.avro(testAvro)
|
||||
val df = spark.read.format("avro").load(testAvro)
|
||||
spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed")
|
||||
df.write.avro(uncompressDir)
|
||||
df.write.format("avro").save(uncompressDir)
|
||||
spark.conf.set(AVRO_COMPRESSION_CODEC, "deflate")
|
||||
spark.conf.set(AVRO_DEFLATE_LEVEL, "9")
|
||||
df.write.avro(deflateDir)
|
||||
df.write.format("avro").save(deflateDir)
|
||||
spark.conf.set(AVRO_COMPRESSION_CODEC, "snappy")
|
||||
df.write.avro(snappyDir)
|
||||
df.write.format("avro").save(snappyDir)
|
||||
|
||||
val uncompressSize = FileUtils.sizeOfDirectory(new File(uncompressDir))
|
||||
val deflateSize = FileUtils.sizeOfDirectory(new File(deflateDir))
|
||||
|
@ -382,49 +383,50 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("dsl test") {
|
||||
val results = spark.read.avro(episodesAvro).select("title").collect()
|
||||
val results = spark.read.format("avro").load(episodesAvro).select("title").collect()
|
||||
assert(results.length === 8)
|
||||
}
|
||||
|
||||
test("support of various data types") {
|
||||
// This test uses data from test.avro. You can see the data and the schema of this file in
|
||||
// test.json and test.avsc
|
||||
val all = spark.read.avro(testAvro).collect()
|
||||
val all = spark.read.format("avro").load(testAvro).collect()
|
||||
assert(all.length == 3)
|
||||
|
||||
val str = spark.read.avro(testAvro).select("string").collect()
|
||||
val str = spark.read.format("avro").load(testAvro).select("string").collect()
|
||||
assert(str.map(_(0)).toSet.contains("Terran is IMBA!"))
|
||||
|
||||
val simple_map = spark.read.avro(testAvro).select("simple_map").collect()
|
||||
val simple_map = spark.read.format("avro").load(testAvro).select("simple_map").collect()
|
||||
assert(simple_map(0)(0).getClass.toString.contains("Map"))
|
||||
assert(simple_map.map(_(0).asInstanceOf[Map[String, Some[Int]]].size).toSet == Set(2, 0))
|
||||
|
||||
val union0 = spark.read.avro(testAvro).select("union_string_null").collect()
|
||||
val union0 = spark.read.format("avro").load(testAvro).select("union_string_null").collect()
|
||||
assert(union0.map(_(0)).toSet == Set("abc", "123", null))
|
||||
|
||||
val union1 = spark.read.avro(testAvro).select("union_int_long_null").collect()
|
||||
val union1 = spark.read.format("avro").load(testAvro).select("union_int_long_null").collect()
|
||||
assert(union1.map(_(0)).toSet == Set(66, 1, null))
|
||||
|
||||
val union2 = spark.read.avro(testAvro).select("union_float_double").collect()
|
||||
val union2 = spark.read.format("avro").load(testAvro).select("union_float_double").collect()
|
||||
assert(
|
||||
union2
|
||||
.map(x => new java.lang.Double(x(0).toString))
|
||||
.exists(p => Math.abs(p - Math.PI) < 0.001))
|
||||
|
||||
val fixed = spark.read.avro(testAvro).select("fixed3").collect()
|
||||
val fixed = spark.read.format("avro").load(testAvro).select("fixed3").collect()
|
||||
assert(fixed.map(_(0).asInstanceOf[Array[Byte]]).exists(p => p(1) == 3))
|
||||
|
||||
val enum = spark.read.avro(testAvro).select("enum").collect()
|
||||
val enum = spark.read.format("avro").load(testAvro).select("enum").collect()
|
||||
assert(enum.map(_(0)).toSet == Set("SPADES", "CLUBS", "DIAMONDS"))
|
||||
|
||||
val record = spark.read.avro(testAvro).select("record").collect()
|
||||
val record = spark.read.format("avro").load(testAvro).select("record").collect()
|
||||
assert(record(0)(0).getClass.toString.contains("Row"))
|
||||
assert(record.map(_(0).asInstanceOf[Row](0)).contains("TEST_STR123"))
|
||||
|
||||
val array_of_boolean = spark.read.avro(testAvro).select("array_of_boolean").collect()
|
||||
val array_of_boolean =
|
||||
spark.read.format("avro").load(testAvro).select("array_of_boolean").collect()
|
||||
assert(array_of_boolean.map(_(0).asInstanceOf[Seq[Boolean]].size).toSet == Set(3, 1, 0))
|
||||
|
||||
val bytes = spark.read.avro(testAvro).select("bytes").collect()
|
||||
val bytes = spark.read.format("avro").load(testAvro).select("bytes").collect()
|
||||
assert(bytes.map(_(0).asInstanceOf[Array[Byte]].length).toSet == Set(3, 1, 0))
|
||||
}
|
||||
|
||||
|
@ -444,7 +446,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
// get the same values back.
|
||||
withTempPath { dir =>
|
||||
val avroDir = s"$dir/avro"
|
||||
spark.read.avro(testAvro).write.avro(avroDir)
|
||||
spark.read.format("avro").load(testAvro).write.format("avro").save(avroDir)
|
||||
checkReloadMatchesSaved(testAvro, avroDir)
|
||||
}
|
||||
}
|
||||
|
@ -458,7 +460,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val parameters = Map("recordName" -> name, "recordNamespace" -> namespace)
|
||||
|
||||
val avroDir = tempDir + "/namedAvro"
|
||||
spark.read.avro(testAvro).write.options(parameters).avro(avroDir)
|
||||
spark.read.format("avro").load(testAvro)
|
||||
.write.options(parameters).format("avro").save(avroDir)
|
||||
checkReloadMatchesSaved(testAvro, avroDir)
|
||||
|
||||
// Look at raw file and make sure has namespace info
|
||||
|
@ -489,22 +492,22 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val cityDataFrame = spark.createDataFrame(cityRDD, testSchema)
|
||||
|
||||
val avroDir = tempDir + "/avro"
|
||||
cityDataFrame.write.avro(avroDir)
|
||||
assert(spark.read.avro(avroDir).collect().length == 3)
|
||||
cityDataFrame.write.format("avro").save(avroDir)
|
||||
assert(spark.read.format("avro").load(avroDir).collect().length == 3)
|
||||
|
||||
// TimesStamps are converted to longs
|
||||
val times = spark.read.avro(avroDir).select("Time").collect()
|
||||
val times = spark.read.format("avro").load(avroDir).select("Time").collect()
|
||||
assert(times.map(_(0)).toSet == Set(666, 777, 42))
|
||||
|
||||
// DecimalType should be converted to string
|
||||
val decimals = spark.read.avro(avroDir).select("Decimal").collect()
|
||||
val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
|
||||
assert(decimals.map(_(0)).contains("3.14"))
|
||||
|
||||
// There should be a null entry
|
||||
val length = spark.read.avro(avroDir).select("Length").collect()
|
||||
val length = spark.read.format("avro").load(avroDir).select("Length").collect()
|
||||
assert(length.map(_(0)).contains(null))
|
||||
|
||||
val binary = spark.read.avro(avroDir).select("Binary").collect()
|
||||
val binary = spark.read.format("avro").load(avroDir).select("Binary").collect()
|
||||
for (i <- arrayOfByte.indices) {
|
||||
assert(binary(1)(0).asInstanceOf[Array[Byte]](i) == arrayOfByte(i))
|
||||
}
|
||||
|
@ -523,10 +526,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val writeDs = Seq((currentDate, currentTime)).toDS
|
||||
|
||||
val avroDir = tempDir + "/avro"
|
||||
writeDs.write.avro(avroDir)
|
||||
assert(spark.read.avro(avroDir).collect().length == 1)
|
||||
writeDs.write.format("avro").save(avroDir)
|
||||
assert(spark.read.format("avro").load(avroDir).collect().length == 1)
|
||||
|
||||
val readDs = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)]
|
||||
val readDs = spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)]
|
||||
|
||||
assert(readDs.collect().sameElements(writeDs.collect()))
|
||||
}
|
||||
|
@ -534,10 +537,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
|
||||
test("support of globbed paths") {
|
||||
val resourceDir = testFile(".")
|
||||
val e1 = spark.read.avro(resourceDir + "../*/episodes.avro").collect()
|
||||
val e1 = spark.read.format("avro").load(resourceDir + "../*/episodes.avro").collect()
|
||||
assert(e1.length == 8)
|
||||
|
||||
val e2 = spark.read.avro(resourceDir + "../../*/*/episodes.avro").collect()
|
||||
val e2 = spark.read.format("avro").load(resourceDir + "../../*/*/episodes.avro").collect()
|
||||
assert(e2.length == 8)
|
||||
}
|
||||
|
||||
|
@ -555,8 +558,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val writeDs = Seq((nullDate, nullTime)).toDS
|
||||
|
||||
val avroDir = tempDir + "/avro"
|
||||
writeDs.write.avro(avroDir)
|
||||
val readValues = spark.read.schema(schema).avro(avroDir).as[(Date, Timestamp)].collect
|
||||
writeDs.write.format("avro").save(avroDir)
|
||||
val readValues =
|
||||
spark.read.schema(schema).format("avro").load(avroDir).as[(Date, Timestamp)].collect
|
||||
|
||||
assert(readValues.size == 1)
|
||||
assert(readValues.head == ((nullDate, nullTime)))
|
||||
|
@ -579,9 +583,10 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val result = spark
|
||||
.read
|
||||
.option("avroSchema", avroSchema)
|
||||
.avro(testAvro)
|
||||
.format("avro")
|
||||
.load(testAvro)
|
||||
.collect()
|
||||
val expected = spark.read.avro(testAvro).select("string").collect()
|
||||
val expected = spark.read.format("avro").load(testAvro).select("string").collect()
|
||||
assert(result.sameElements(expected))
|
||||
}
|
||||
|
||||
|
@ -601,7 +606,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val result = spark
|
||||
.read
|
||||
.option("avroSchema", avroSchema)
|
||||
.avro(testAvro).select("missingField").first
|
||||
.format("avro").load(testAvro).select("missingField").first
|
||||
assert(result === Row("foo"))
|
||||
}
|
||||
|
||||
|
@ -609,17 +614,17 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
|
||||
// Directory given has no avro files
|
||||
intercept[AnalysisException] {
|
||||
withTempPath(dir => spark.read.avro(dir.getCanonicalPath))
|
||||
withTempPath(dir => spark.read.format("avro").load(dir.getCanonicalPath))
|
||||
}
|
||||
|
||||
intercept[AnalysisException] {
|
||||
spark.read.avro("very/invalid/path/123.avro")
|
||||
spark.read.format("avro").load("very/invalid/path/123.avro")
|
||||
}
|
||||
|
||||
// In case of globbed path that can't be matched to anything, another exception is thrown (and
|
||||
// exception message is helpful)
|
||||
intercept[AnalysisException] {
|
||||
spark.read.avro("*/*/*/*/*/*/*/something.avro")
|
||||
spark.read.format("avro").load("*/*/*/*/*/*/*/something.avro")
|
||||
}
|
||||
|
||||
intercept[FileNotFoundException] {
|
||||
|
@ -628,7 +633,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
|
||||
try {
|
||||
hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
|
||||
spark.read.avro(dir.toString)
|
||||
spark.read.format("avro").load(dir.toString)
|
||||
} finally {
|
||||
hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
|
||||
}
|
||||
|
@ -642,7 +647,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
spark
|
||||
.read
|
||||
.option("ignoreExtension", false)
|
||||
.avro(dir.toString)
|
||||
.format("avro")
|
||||
.load(dir.toString)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -681,13 +687,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
test("test save and load") {
|
||||
// Test if load works as expected
|
||||
withTempPath { tempDir =>
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
assert(df.count == 8)
|
||||
|
||||
val tempSaveDir = s"$tempDir/save/"
|
||||
|
||||
df.write.avro(tempSaveDir)
|
||||
val newDf = spark.read.avro(tempSaveDir)
|
||||
df.write.format("avro").save(tempSaveDir)
|
||||
val newDf = spark.read.format("avro").load(tempSaveDir)
|
||||
assert(newDf.count == 8)
|
||||
}
|
||||
}
|
||||
|
@ -695,20 +701,18 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
test("test load with non-Avro file") {
|
||||
// Test if load works as expected
|
||||
withTempPath { tempDir =>
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
assert(df.count == 8)
|
||||
|
||||
val tempSaveDir = s"$tempDir/save/"
|
||||
df.write.avro(tempSaveDir)
|
||||
df.write.format("avro").save(tempSaveDir)
|
||||
|
||||
Files.createFile(new File(tempSaveDir, "non-avro").toPath)
|
||||
|
||||
val hadoopConf = spark.sqlContext.sparkContext.hadoopConfiguration
|
||||
val count = try {
|
||||
hadoopConf.set(AvroFileFormat.IgnoreFilesWithoutExtensionProperty, "true")
|
||||
val newDf = spark
|
||||
.read
|
||||
.avro(tempSaveDir)
|
||||
val newDf = spark.read.format("avro").load(tempSaveDir)
|
||||
newDf.count()
|
||||
} finally {
|
||||
hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
|
||||
|
@ -730,10 +734,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
StructField("record", StructType(Seq(StructField("value_field", StringType, false))), false),
|
||||
StructField("array_of_boolean", ArrayType(BooleanType), false),
|
||||
StructField("bytes", BinaryType, true)))
|
||||
val withSchema = spark.read.schema(partialColumns).avro(testAvro).collect()
|
||||
val withSchema = spark.read.schema(partialColumns).format("avro").load(testAvro).collect()
|
||||
val withOutSchema = spark
|
||||
.read
|
||||
.avro(testAvro)
|
||||
.format("avro")
|
||||
.load(testAvro)
|
||||
.select("string", "simple_map", "complex_map", "union_string_null", "union_int_long_null",
|
||||
"fixed3", "fixed2", "enum", "record", "array_of_boolean", "bytes")
|
||||
.collect()
|
||||
|
@ -751,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
StructField("non_exist_field", StringType, false),
|
||||
StructField("non_exist_field2", StringType, false))),
|
||||
false)))
|
||||
val withEmptyColumn = spark.read.schema(schema).avro(testAvro).collect()
|
||||
val withEmptyColumn = spark.read.schema(schema).format("avro").load(testAvro).collect()
|
||||
|
||||
assert(withEmptyColumn.forall(_ == Row(null: String, Row(null: String, null: String))))
|
||||
}
|
||||
|
@ -762,8 +767,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
import sparkSession.implicits._
|
||||
val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records")
|
||||
val outputDir = s"$dir/${UUID.randomUUID}"
|
||||
df.write.avro(outputDir)
|
||||
val input = spark.read.avro(outputDir)
|
||||
df.write.format("avro").save(outputDir)
|
||||
val input = spark.read.format("avro").load(outputDir)
|
||||
assert(input.collect.toSet.size === 1024 * 3 + 1)
|
||||
assert(input.rdd.partitions.size > 2)
|
||||
}
|
||||
|
@ -780,9 +785,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
// Save avro file on output folder path
|
||||
val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1")))))
|
||||
val outputFolder = s"$tempDir/duplicate_names/"
|
||||
writeDf.write.avro(outputFolder)
|
||||
writeDf.write.format("avro").save(outputFolder)
|
||||
// Read avro file saved on the last step
|
||||
val readDf = spark.read.avro(outputFolder)
|
||||
val readDf = spark.read.format("avro").load(outputFolder)
|
||||
// Check if the written DataFrame is equals than read DataFrame
|
||||
assert(readDf.collect().sameElements(writeDf.collect()))
|
||||
}
|
||||
|
@ -801,9 +806,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
))))
|
||||
)
|
||||
val outputFolder = s"$tempDir/duplicate_names_array/"
|
||||
writeDf.write.avro(outputFolder)
|
||||
writeDf.write.format("avro").save(outputFolder)
|
||||
// Read avro file saved on the last step
|
||||
val readDf = spark.read.avro(outputFolder)
|
||||
val readDf = spark.read.format("avro").load(outputFolder)
|
||||
// Check if the written DataFrame is equals than read DataFrame
|
||||
assert(readDf.collect().sameElements(writeDf.collect()))
|
||||
}
|
||||
|
@ -822,9 +827,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
))))
|
||||
)
|
||||
val outputFolder = s"$tempDir/duplicate_names_map/"
|
||||
writeDf.write.avro(outputFolder)
|
||||
writeDf.write.format("avro").save(outputFolder)
|
||||
// Read avro file saved on the last step
|
||||
val readDf = spark.read.avro(outputFolder)
|
||||
val readDf = spark.read.format("avro").load(outputFolder)
|
||||
// Check if the written DataFrame is equals than read DataFrame
|
||||
assert(readDf.collect().sameElements(writeDf.collect()))
|
||||
}
|
||||
|
@ -837,32 +842,33 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
Paths.get(dir.getCanonicalPath, "episodes"))
|
||||
|
||||
val fileWithoutExtension = s"${dir.getCanonicalPath}/episodes"
|
||||
val df1 = spark.read.avro(fileWithoutExtension)
|
||||
val df1 = spark.read.format("avro").load(fileWithoutExtension)
|
||||
assert(df1.count == 8)
|
||||
|
||||
val schema = new StructType()
|
||||
.add("title", StringType)
|
||||
.add("air_date", StringType)
|
||||
.add("doctor", IntegerType)
|
||||
val df2 = spark.read.schema(schema).avro(fileWithoutExtension)
|
||||
val df2 = spark.read.schema(schema).format("avro").load(fileWithoutExtension)
|
||||
assert(df2.count == 8)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-24836: checking the ignoreExtension option") {
|
||||
withTempPath { tempDir =>
|
||||
val df = spark.read.avro(episodesAvro)
|
||||
val df = spark.read.format("avro").load(episodesAvro)
|
||||
assert(df.count == 8)
|
||||
|
||||
val tempSaveDir = s"$tempDir/save/"
|
||||
df.write.avro(tempSaveDir)
|
||||
df.write.format("avro").save(tempSaveDir)
|
||||
|
||||
Files.createFile(new File(tempSaveDir, "non-avro").toPath)
|
||||
|
||||
val newDf = spark
|
||||
.read
|
||||
.option("ignoreExtension", false)
|
||||
.avro(tempSaveDir)
|
||||
.format("avro")
|
||||
.load(tempSaveDir)
|
||||
|
||||
assert(newDf.count == 8)
|
||||
}
|
||||
|
@ -880,7 +886,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
|
|||
val newDf = spark
|
||||
.read
|
||||
.option("ignoreExtension", "true")
|
||||
.avro(s"${dir.getCanonicalPath}/episodes")
|
||||
.format("avro")
|
||||
.load(s"${dir.getCanonicalPath}/episodes")
|
||||
newDf.count()
|
||||
} finally {
|
||||
hadoopConf.unset(AvroFileFormat.IgnoreFilesWithoutExtensionProperty)
|
||||
|
|
Loading…
Reference in a new issue