[SPARK-10289] [SQL] A direct write API for testing Parquet

This PR introduces a direct write API for testing Parquet. It's a DSL-flavored version of the [`writeDirect` method][1] that comes with the parquet-avro testing code. With this API, it's much easier to construct arbitrary Parquet structures. It's especially useful when adding regression tests for various compatibility corner cases.

Sample usage of this API can be found in the new test case added in `ParquetThriftCompatibilitySuite`.

[1]: https://github.com/apache/parquet-mr/blob/apache-parquet-1.8.1/parquet-avro/src/test/java/org/apache/parquet/avro/TestArrayCompatibility.java#L945-L972
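For a quick sense of the DSL, a call site looks roughly like the sketch below. The schema, output path, and values here are made up for illustration (the real usage is the new SPARK-10136 test in the diff), and the implicit `RecordConsumerDSL` conversion has to be in scope, e.g. via `import ParquetCompatibilityTest._` from code inside Spark's `sql` package, since the object is `private[sql]`:

```scala
import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest._

// Hypothetical example: write two records of a single optional INT32 column `id`
// directly to a Parquet file, one record writer per record.
writeDirect(
  "/tmp/direct-write-example.parquet",    // illustrative output path
  """message Example {
    |  optional int32 id;
    |}
  """.stripMargin,
  { rc =>
    rc.message {                          // startMessage() ... endMessage()
      rc.field("id", 0) {                 // startField("id", 0) ... endField("id", 0)
        rc.addInteger(0)
      }
    }
  },
  { rc =>
    rc.message {
      rc.field("id", 0) {
        rc.addInteger(1)
      }
    }
  })
```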

Author: Cheng Lian <lian@databricks.com>

Closes #8454 from liancheng/spark-10289/parquet-testing-direct-write-api.
Commit 24ffa85c00 (parent 5369be8068)
Authored by Cheng Lian on 2015-08-29 13:24:32 -07:00; committed by Michael Armbrust
2 changed files with 159 additions and 23 deletions

ParquetCompatibilityTest.scala

@@ -17,11 +17,15 @@
 package org.apache.spark.sql.execution.datasources.parquet
 
-import scala.collection.JavaConverters._
+import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, PathFilter}
-import org.apache.parquet.hadoop.ParquetFileReader
-import org.apache.parquet.schema.MessageType
+import org.apache.parquet.hadoop.api.WriteSupport
+import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
+import org.apache.parquet.io.api.RecordConsumer
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 
 import org.apache.spark.sql.QueryTest
@@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
     val fs = fsPath.getFileSystem(configuration)
     val parquetFiles = fs.listStatus(fsPath, new PathFilter {
       override def accept(path: Path): Boolean = pathFilter(path)
-    }).toSeq
+    }).toSeq.asJava
 
-    val footers =
-      ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true)
-    footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema
+    val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
+    footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
   }
 
   protected def logParquetSchema(path: String): Unit = {
@@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
   }
 }
 
-object ParquetCompatibilityTest {
-  def makeNullable[T <: AnyRef](i: Int)(f: => T): T = {
-    if (i % 3 == 0) null.asInstanceOf[T] else f
+private[sql] object ParquetCompatibilityTest {
+  implicit class RecordConsumerDSL(consumer: RecordConsumer) {
+    def message(f: => Unit): Unit = {
+      consumer.startMessage()
+      f
+      consumer.endMessage()
+    }
+
+    def group(f: => Unit): Unit = {
+      consumer.startGroup()
+      f
+      consumer.endGroup()
+    }
+
+    def field(name: String, index: Int)(f: => Unit): Unit = {
+      consumer.startField(name, index)
+      f
+      consumer.endField(name, index)
+    }
+  }
+
+  /**
+   * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet
+   * records with arbitrary structures.
+   */
+  private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
+    extends WriteSupport[RecordConsumer => Unit] {
+
+    private var recordConsumer: RecordConsumer = _
+
+    override def init(configuration: Configuration): WriteContext = {
+      new WriteContext(schema, metadata.asJava)
+    }
+
+    override def write(recordWriter: RecordConsumer => Unit): Unit = {
+      recordWriter.apply(recordConsumer)
+    }
+
+    override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+      this.recordConsumer = recordConsumer
+    }
+  }
+
+  /**
+   * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`.
+   * Records are produced by `recordWriters`.
+   */
+  def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = {
+    writeDirect(path, schema, Map.empty[String, String], recordWriters: _*)
+  }
+
+  /**
+   * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`
+   * with given user-defined key-value `metadata`. Records are produced by `recordWriters`.
+   */
+  def writeDirect(
+      path: String,
+      schema: String,
+      metadata: Map[String, String],
+      recordWriters: (RecordConsumer => Unit)*): Unit = {
+    val messageType = MessageTypeParser.parseMessageType(schema)
+    val writeSupport = new DirectWriteSupport(messageType, metadata)
+    val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
+    try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close()
   }
 }
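For reference, the `message`/`group`/`field` blocks introduced above are thin wrappers that pair up Parquet's low-level `RecordConsumer` start/end calls. A record written through the DSL is equivalent to a hand-written record writer along these lines (the field name and value here are illustrative):

```scala
import org.apache.parquet.io.api.RecordConsumer

// Roughly what `rc.message { rc.field("id", 0) { rc.addInteger(42) } }` expands to.
def writeOneRecord(rc: RecordConsumer): Unit = {
  rc.startMessage()
  rc.startField("id", 0)
  rc.addInteger(42)
  rc.endField("id", 0)
  rc.endMessage()
}
```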

ParquetThriftCompatibilitySuite.scala

@@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
       """.stripMargin)
 
     checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
-      def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)
-
       val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")
 
-      Row(
+      val nonNullablePrimitiveValues = Seq(
         i % 2 == 0,
         i.toByte,
         (i + 1).toShort,
@@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
         s"val_$i",
         s"val_$i",
         // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
-        suits(i % 4),
+        suits(i % 4))
 
-        nullable(i % 2 == 0: java.lang.Boolean),
-        nullable(i.toByte: java.lang.Byte),
-        nullable((i + 1).toShort: java.lang.Short),
-        nullable(i + 2: Integer),
-        nullable((i * 10).toLong: java.lang.Long),
-        nullable(i.toDouble + 0.2d: java.lang.Double),
-        nullable(s"val_$i"),
-        nullable(s"val_$i"),
-        nullable(suits(i % 4)),
-
+      val nullablePrimitiveValues = if (i % 3 == 0) {
+        Seq.fill(nonNullablePrimitiveValues.length)(null)
+      } else {
+        nonNullablePrimitiveValues
+      }
+
+      val complexValues = Seq(
         Seq.tabulate(3)(n => s"arr_${i + n}"),
         // Thrift `SET`s are converted to Parquet `LIST`s
         Seq(i),
@@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
           Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
         }
       }.toMap)
+
+      Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
     })
   }
+
+  test("SPARK-10136 list of primitive list") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      // This Parquet schema is translated from the following Thrift schema:
+      //
+      //   struct ListOfPrimitiveList {
+      //     1: list<list<i32>> f;
+      //   }
+      val schema =
+        s"""message ListOfPrimitiveList {
+           |  required group f (LIST) {
+           |    repeated group f_tuple (LIST) {
+           |      repeated int32 f_tuple_tuple;
+           |    }
+           |  }
+           |}
+         """.stripMargin
+
+      writeDirect(path, schema, { rc =>
+        rc.message {
+          rc.field("f", 0) {
+            rc.group {
+              rc.field("f_tuple", 0) {
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(0)
+                    rc.addInteger(1)
+                  }
+                }
+
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(2)
+                    rc.addInteger(3)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }, { rc =>
+        rc.message {
+          rc.field("f", 0) {
+            rc.group {
+              rc.field("f_tuple", 0) {
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(4)
+                    rc.addInteger(5)
+                  }
+                }
+
+                rc.group {
+                  rc.field("f_tuple_tuple", 0) {
+                    rc.addInteger(6)
+                    rc.addInteger(7)
+                  }
+                }
+              }
+            }
+          }
+        }
+      })
+
+      logParquetSchema(path)
+
+      checkAnswer(
+        sqlContext.read.parquet(path),
+        Seq(
+          Row(Seq(Seq(0, 1), Seq(2, 3))),
+          Row(Seq(Seq(4, 5), Seq(6, 7)))))
+    }
+  }
 }