[SPARK-10289] [SQL] A direct write API for testing Parquet
This PR introduces a direct write API for testing Parquet. It's a DSL-flavored version of the [`writeDirect` method][1] that comes with the parquet-avro testing code. With this API, it's much easier to construct arbitrary Parquet structures, which is especially useful when adding regression tests for various compatibility corner cases. Sample usage of the API can be found in the new test case added in `ParquetThriftCompatibilitySuite`.

[1]: https://github.com/apache/parquet-mr/blob/apache-parquet-1.8.1/parquet-avro/src/test/java/org/apache/parquet/avro/TestArrayCompatibility.java#L945-L972

Author: Cheng Lian <lian@databricks.com>

Closes #8454 from liancheng/spark-10289/parquet-testing-direct-write-api.
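For a quick feel of the DSL, here is a minimal sketch (the output path, schema, and field name are made up for illustration; it assumes `ParquetCompatibilityTest.writeDirect` and the `RecordConsumerDSL` implicit class added by this patch are in scope):

```scala
import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest._

// Writes a single message with one required int32 field. Each DSL block
// brackets its body with the corresponding RecordConsumer start/end calls.
writeDirect(
  "/tmp/direct.parquet",                   // hypothetical output path
  "message root { required int32 id; }",   // schema in Parquet message-type syntax
  { rc =>
    rc.message {                           // startMessage() / endMessage()
      rc.field("id", 0) {                  // startField("id", 0) / endField("id", 0)
        rc.addInteger(42)
      }
    }
  })
```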
This commit is contained in:
parent 5369be8068
commit 24ffa85c00
ParquetCompatibilityTest.scala

@@ -17,11 +17,15 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
-import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}
+import scala.collection.JavaConverters._
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, PathFilter}
-import org.apache.parquet.hadoop.ParquetFileReader
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.hadoop.api.WriteSupport
+import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
+import org.apache.parquet.io.api.RecordConsumer
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 
 import org.apache.spark.sql.QueryTest
@@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
     val fs = fsPath.getFileSystem(configuration)
     val parquetFiles = fs.listStatus(fsPath, new PathFilter {
       override def accept(path: Path): Boolean = pathFilter(path)
-    }).toSeq
+    }).toSeq.asJava
 
-    val footers =
-      ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true)
-    footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema
+    val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
+    footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
   }
 
   protected def logParquetSchema(path: String): Unit = {
@@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
   }
 }
 
-object ParquetCompatibilityTest {
-  def makeNullable[T <: AnyRef](i: Int)(f: => T): T = {
-    if (i % 3 == 0) null.asInstanceOf[T] else f
+private[sql] object ParquetCompatibilityTest {
+  implicit class RecordConsumerDSL(consumer: RecordConsumer) {
+    def message(f: => Unit): Unit = {
+      consumer.startMessage()
+      f
+      consumer.endMessage()
+    }
+
+    def group(f: => Unit): Unit = {
+      consumer.startGroup()
+      f
+      consumer.endGroup()
+    }
+
+    def field(name: String, index: Int)(f: => Unit): Unit = {
+      consumer.startField(name, index)
+      f
+      consumer.endField(name, index)
+    }
+  }
+
+  /**
+   * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet
+   * records with arbitrary structures.
+   */
+  private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
+    extends WriteSupport[RecordConsumer => Unit] {
+
+    private var recordConsumer: RecordConsumer = _
+
+    override def init(configuration: Configuration): WriteContext = {
+      new WriteContext(schema, metadata.asJava)
+    }
+
+    override def write(recordWriter: RecordConsumer => Unit): Unit = {
+      recordWriter.apply(recordConsumer)
+    }
+
+    override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+      this.recordConsumer = recordConsumer
+    }
+  }
+
+  /**
+   * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`.
+   * Records are produced by `recordWriters`.
+   */
+  def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = {
+    writeDirect(path, schema, Map.empty[String, String], recordWriters: _*)
+  }
+
+  /**
+   * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`
+   * with given user-defined key-value `metadata`. Records are produced by `recordWriters`.
+   */
+  def writeDirect(
+      path: String,
+      schema: String,
+      metadata: Map[String, String],
+      recordWriters: (RecordConsumer => Unit)*): Unit = {
+    val messageType = MessageTypeParser.parseMessageType(schema)
+    val writeSupport = new DirectWriteSupport(messageType, metadata)
+    val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
+    try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close()
   }
 }
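A note on the DSL above: each method simply brackets its body with the matching `RecordConsumer` start/end calls, so the nested blocks in the new test below are a thin layer over the raw API. As a sketch (with a hypothetical field name `f`):

    // rc.message { rc.field("f", 0) { rc.addInteger(42) } }
    // is equivalent to:
    rc.startMessage()
    rc.startField("f", 0)
    rc.addInteger(42)
    rc.endField("f", 0)
    rc.endMessage()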
ParquetThriftCompatibilitySuite.scala

@@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
       """.stripMargin)
 
     checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
-      def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)
-
       val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
 
-      Row(
+      val nonNullablePrimitiveValues = Seq(
         i % 2 == 0,
         i.toByte,
         (i + 1).toShort,
@@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
         s"val_$i",
         s"val_$i",
         // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
-        suits(i % 4),
-
-        nullable(i % 2 == 0: java.lang.Boolean),
-        nullable(i.toByte: java.lang.Byte),
-        nullable((i + 1).toShort: java.lang.Short),
-        nullable(i + 2: Integer),
-        nullable((i * 10).toLong: java.lang.Long),
-        nullable(i.toDouble + 0.2d: java.lang.Double),
-        nullable(s"val_$i"),
-        nullable(s"val_$i"),
-        nullable(suits(i % 4)),
-
+        suits(i % 4))
+
+      val nullablePrimitiveValues = if (i % 3 == 0) {
+        Seq.fill(nonNullablePrimitiveValues.length)(null)
+      } else {
+        nonNullablePrimitiveValues
+      }
+
+      val complexValues = Seq(
         Seq.tabulate(3)(n => s"arr_${i + n}"),
         // Thrift `SET`s are converted to Parquet `LIST`s
         Seq(i),
@@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
           Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
         }
       }.toMap)
+
+      Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
     })
   }
+
+  test("SPARK-10136 list of primitive list") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      // This Parquet schema is translated from the following Thrift schema:
+      //
+      //   struct ListOfPrimitiveList {
+      //     1: list<list<i32>> f;
+      //   }
+      val schema =
+        s"""message ListOfPrimitiveList {
+           |  required group f (LIST) {
+           |    repeated group f_tuple (LIST) {
+           |      repeated int32 f_tuple_tuple;
+           |    }
+           |  }
+           |}
+         """.stripMargin
+
+      writeDirect(path, schema, { rc =>
+        rc.message {
+          rc.field("f", 0) {
+            rc.group {
+              rc.field("f_tuple", 0) {
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(0)
+                    rc.addInteger(1)
+                  }
+                }
+
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(2)
+                    rc.addInteger(3)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }, { rc =>
+        rc.message {
+          rc.field("f", 0) {
+            rc.group {
+              rc.field("f_tuple", 0) {
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(4)
+                    rc.addInteger(5)
+                  }
+                }
+
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(6)
+                    rc.addInteger(7)
+                  }
+                }
+              }
+            }
+          }
+        }
+      })
+
+      logParquetSchema(path)
+
+      checkAnswer(
+        sqlContext.read.parquet(path),
+        Seq(
+          Row(Seq(Seq(0, 1), Seq(2, 3))),
+          Row(Seq(Seq(4, 5), Seq(6, 7)))))
+    }
+  }
 }