diff --git a/external/avro/src/test/resources/date.avro b/external/avro/src/test/resources/date.avro deleted file mode 100644 index 3a6761704c..0000000000 Binary files a/external/avro/src/test/resources/date.avro and /dev/null differ diff --git a/external/avro/src/test/resources/timestamp.avro b/external/avro/src/test/resources/timestamp.avro deleted file mode 100644 index daef50b78b..0000000000 Binary files a/external/avro/src/test/resources/timestamp.avro and /dev/null differ diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala new file mode 100644 index 0000000000..24d8c53764 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.avro + +import java.io.File +import java.sql.Timestamp + +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.file.DataFileWriter +import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} + +import org.apache.spark.SparkException +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.types.{StructField, StructType, TimestampType} + +class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils { + import testImplicits._ + + val dateSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "date", "type": {"type": "int", "logicalType": "date"}} + ] + } + """ + + val dateInputData = Seq(7, 365, 0) + + def dateFile(path: String): String = { + val schema = new Schema.Parser().parse(dateSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + dateInputData.foreach { x => + val record = new GenericData.Record(schema) + record.put("date", x) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: date") { + withTempDir { dir => + val expected = dateInputData.map(t => Row(DateTimeUtils.toJavaDate(t))) + val dateAvro = dateFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(dateAvro) + + checkAnswer(df, expected) + + checkAnswer(spark.read.format("avro").option("avroSchema", dateSchema).load(dateAvro), + expected) + + withTempPath { path => + 
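+        // Round-trip check: write the decoded dates back out in Avro format and verify
+        // that re-reading them yields the same rows.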
df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + val timestampSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, + {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, + {"name": "long", "type": "long"} + ] + } + """ + + val timestampInputData = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) + + def timestampFile(path: String): String = { + val schema = new Schema.Parser().parse(timestampSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val result = s"$path/test.avro" + dataFileWriter.create(schema, new File(result)) + + timestampInputData.foreach { t => + val record = new GenericData.Record(schema) + record.put("timestamp_millis", t._1) + // For microsecond precision, we multiple the value by 1000 to match the expected answer as + // timestamp with millisecond precision. + record.put("timestamp_micros", t._2 * 1000) + record.put("long", t._3) + dataFileWriter.append(record) + } + dataFileWriter.flush() + dataFileWriter.close() + result + } + + test("Logical type: timestamp_millis") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._1))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: timestamp_micros") { + withTempDir { dir => + val expected = timestampInputData.map(t => Row(new Timestamp(t._2))) + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) + + checkAnswer(df, expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: specify different output timestamp types") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val df = + spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) + + Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => + withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + } + } + + test("Read Long type as Timestamp") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val schema = StructType(StructField("long", TimestampType, true) :: Nil) + val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) + + val expected = timestampInputData.map(t => Row(new Timestamp(t._3))) + + checkAnswer(df, expected) + } + } + + test("Logical type: user specified schema") { + withTempDir { dir => + val timestampAvro = timestampFile(dir.getAbsolutePath) + val expected = timestampInputData + .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) + + val df = 
spark.read.format("avro").option("avroSchema", timestampSchema).load(timestampAvro) + checkAnswer(df, expected) + } + } + + val decimalInputData = Seq("1.23", "4.56", "78.90", "-1", "-2.31") + + def decimalSchemaAndFile(path: String): (String, String) = { + val precision = 4 + val scale = 2 + val bytesFieldName = "bytes" + val bytesSchema = s"""{ + "type":"bytes", + "logicalType":"decimal", + "precision":$precision, + "scale":$scale + } + """ + + val fixedFieldName = "fixed" + val fixedSchema = s"""{ + "type":"fixed", + "size":5, + "logicalType":"decimal", + "precision":$precision, + "scale":$scale, + "name":"foo" + } + """ + val avroSchema = s""" + { + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [ + {"name": "$bytesFieldName", "type": $bytesSchema}, + {"name": "$fixedFieldName", "type": $fixedSchema} + ] + } + """ + val schema = new Schema.Parser().parse(avroSchema) + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + val decimalConversion = new DecimalConversion + val avroFile = s"$path/test.avro" + dataFileWriter.create(schema, new File(avroFile)) + val logicalType = LogicalTypes.decimal(precision, scale) + + decimalInputData.map { x => + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal(x).setScale(scale) + val bytes = + decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) + avroRec.put(bytesFieldName, bytes) + val fixed = + decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) + avroRec.put(fixedFieldName, fixed) + dataFileWriter.append(avroRec) + } + dataFileWriter.flush() + dataFileWriter.close() + + (avroSchema, avroFile) + } + + test("Logical type: Decimal") { + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + checkAnswer(df, expected) + checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), + expected) + + withTempPath { path => + df.write.format("avro").save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + + test("Logical type: Decimal with too large precision") { + withTempDir { dir => + val schema = new Schema.Parser().parse("""{ + "namespace": "logical", + "type": "record", + "name": "test", + "fields": [{ + "name": "decimal", + "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} + }] + }""") + val datumWriter = new GenericDatumWriter[GenericRecord](schema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(schema, new File(s"$dir.avro")) + val avroRec = new GenericData.Record(schema) + val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") + val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) + avroRec.put("decimal", bytes) + dataFileWriter.append(avroRec) + dataFileWriter.flush() + dataFileWriter.close() + + val msg = intercept[SparkException] { + spark.read.format("avro").load(s"$dir.avro").collect() + }.getCause.getMessage + assert(msg.contains("Unscaled value too large for precision")) + } + } +} diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala 
b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 3fa43bf929..b07b1464ef 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -25,8 +25,7 @@ import java.util.{TimeZone, UUID} import scala.collection.JavaConverters._ -import org.apache.avro.{LogicalTypes, Schema} -import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} @@ -35,7 +34,6 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -47,50 +45,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val episodesAvro = testFile("episodes.avro") val testAvro = testFile("test.avro") - // The test file timestamp.avro is generated via following Python code: - // import json - // import avro.schema - // from avro.datafile import DataFileWriter - // from avro.io import DatumWriter - // - // write_schema = avro.schema.parse(json.dumps({ - // "namespace": "logical", - // "type": "record", - // "name": "test", - // "fields": [ - // {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, - // {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, - // {"name": "long", "type": "long"} - // ] - // })) - // - // writer = DataFileWriter(open("timestamp.avro", "wb"), DatumWriter(), write_schema) - // writer.append({"timestamp_millis": 1000, "timestamp_micros": 2000000, "long": 3000}) - // writer.append({"timestamp_millis": 666000, "timestamp_micros": 999000000, "long": 777000}) - // writer.close() - val timestampAvro = testFile("timestamp.avro") - - // The test file date.avro is generated via following Python code: - // import json - // import avro.schema - // from avro.datafile import DataFileWriter - // from avro.io import DatumWriter - // - // write_schema = avro.schema.parse(json.dumps({ - // "namespace": "logical", - // "type": "record", - // "name": "test", - // "fields": [ - // {"name": "date", "type": {"type": "int", "logicalType": "date"}} - // ] - // })) - // - // writer = DataFileWriter(open("date.avro", "wb"), DatumWriter(), write_schema) - // writer.append({"date": 7}) - // writer.append({"date": 365}) - // writer.close() - val dateAvro = testFile("date.avro") - override protected def beforeAll(): Unit = { super.beforeAll() spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) @@ -399,200 +353,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Logical type: date") { - val expected = Seq(7, 365).map(t => Row(DateTimeUtils.toJavaDate(t))) - val df = spark.read.format("avro").load(dateAvro) - - checkAnswer(df, expected) - - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "date", "type": {"type": "int", "logicalType": "date"}} - ] - } - """ - - checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(dateAvro), - expected) - - withTempPath { dir => - 
df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: timestamp_millis") { - val expected = Seq(1000L, 666000L).map(t => Row(new Timestamp(t))) - val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis) - - checkAnswer(df, expected) - - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: timestamp_micros") { - val expected = Seq(2000L, 999000L).map(t => Row(new Timestamp(t))) - val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros) - - checkAnswer(df, expected) - - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - - test("Logical type: specify different output timestamp types") { - val df = - spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros) - - val expected = Seq((1000L, 2000L), (666000L, 999000L)) - .map(t => Row(new Timestamp(t._1), new Timestamp(t._2))) - - Seq("TIMESTAMP_MILLIS", "TIMESTAMP_MICROS").foreach { timestampType => - withSQLConf(SQLConf.AVRO_OUTPUT_TIMESTAMP_TYPE.key -> timestampType) { - withTempPath { dir => - df.write.format("avro").save(dir.toString) - checkAnswer(spark.read.format("avro").load(dir.toString), expected) - } - } - } - } - - test("Read Long type as Timestamp") { - val schema = StructType(StructField("long", TimestampType, true) :: Nil) - val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long) - - val expected = Seq(3000L, 777000L).map(t => Row(new Timestamp(t))) - - checkAnswer(df, expected) - } - - test("Logical type: user specified schema") { - val expected = Seq((1000L, 2000L, 3000L), (666000L, 999000L, 777000L)) - .map(t => Row(new Timestamp(t._1), new Timestamp(t._2), t._3)) - - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "timestamp_millis", "type": {"type": "long","logicalType": "timestamp-millis"}}, - {"name": "timestamp_micros", "type": {"type": "long","logicalType": "timestamp-micros"}}, - {"name": "long", "type": "long"} - ] - } - """ - val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro) - checkAnswer(df, expected) - } - - test("Logical type: Decimal") { - val precision = 4 - val scale = 2 - val bytesFieldName = "bytes" - val bytesSchema = s"""{ - "type":"bytes", - "logicalType":"decimal", - "precision":$precision, - "scale":$scale - } - """ - - val fixedFieldName = "fixed" - val fixedSchema = s"""{ - "type":"fixed", - "size":5, - "logicalType":"decimal", - "precision":$precision, - "scale":$scale, - "name":"foo" - } - """ - val avroSchema = s""" - { - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [ - {"name": "$bytesFieldName", "type": $bytesSchema}, - {"name": "$fixedFieldName", "type": $fixedSchema} - ] - } - """ - val schema = new Schema.Parser().parse(avroSchema) - val datumWriter = new GenericDatumWriter[GenericRecord](schema) - val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) - val decimalConversion = new DecimalConversion - withTempDir { dir => - val avroFile = s"$dir.avro" - dataFileWriter.create(schema, new File(avroFile)) - val logicalType = LogicalTypes.decimal(precision, scale) - val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31") - data.map { x => - val avroRec = 
new GenericData.Record(schema) - val decimal = new java.math.BigDecimal(x).setScale(scale) - val bytes = - decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType) - avroRec.put(bytesFieldName, bytes) - val fixed = - decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType) - avroRec.put(fixedFieldName, fixed) - dataFileWriter.append(avroRec) - } - dataFileWriter.flush() - dataFileWriter.close() - - val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } - val df = spark.read.format("avro").load(avroFile) - checkAnswer(df, expected) - checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile), - expected) - - withTempPath { path => - df.write.format("avro").save(path.toString) - checkAnswer(spark.read.format("avro").load(path.toString), expected) - } - } - } - - test("Logical type: Decimal with too large precision") { - withTempDir { dir => - val schema = new Schema.Parser().parse("""{ - "namespace": "logical", - "type": "record", - "name": "test", - "fields": [{ - "name": "decimal", - "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2} - }] - }""") - val datumWriter = new GenericDatumWriter[GenericRecord](schema) - val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) - dataFileWriter.create(schema, new File(s"$dir.avro")) - val avroRec = new GenericData.Record(schema) - val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678") - val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38)) - avroRec.put("decimal", bytes) - dataFileWriter.append(avroRec) - dataFileWriter.flush() - dataFileWriter.close() - - val msg = intercept[SparkException] { - spark.read.format("avro").load(s"$dir.avro").collect() - }.getCause.getMessage - assert(msg.contains("Unscaled value too large for precision")) - } - } - test("Array data types") { withTempPath { dir => val testSchema = StructType(Seq(