[SPARK-16101][SQL] Refactoring CSV write path to be consistent with JSON data source

## What changes were proposed in this pull request?

This PR refactors the CSV write path to be consistent with the JSON data source.

This PR gives the methods in the following classes consistent arguments between the CSV and JSON sides (a short usage sketch follows the list below).
  - `UnivocityGenerator` and `JacksonGenerator`

    ``` scala
    private[csv] class UnivocityGenerator(
        schema: StructType,
        writer: Writer,
        options: CSVOptions = new CSVOptions(Map.empty[String, String])) {
    ...

    def write ...
    def close ...
    def flush ...
    ```

    ``` scala
    private[sql] class JacksonGenerator(
       schema: StructType,
       writer: Writer,
       options: JSONOptions = new JSONOptions(Map.empty[String, String])) {
    ...

    def write ...
    def close ...
    def flush ...
    ```

- This PR also groups the classes together within each file, in the same way as the JSON data source.
  - `CSVFileFormat`

    ``` scala
    CSVFileFormat
    CsvOutputWriter
    ```

  - `JsonFileFormat`

    ``` scala
    JsonFileFormat
    JsonOutputWriter
    ```
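
For illustration, here is a minimal usage sketch of the refactored write-side API. This is not code from the PR: the sample schema and values are made up, and because these classes are `private[csv]`, such a caller would have to live in `org.apache.spark.sql.execution.datasources.csv`. The point is that `UnivocityGenerator` now exposes the same `write`/`flush`/`close` surface as `JacksonGenerator`.

``` scala
import java.io.StringWriter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical schema and rows, purely for illustration.
val schema = new StructType().add("name", StringType).add("age", IntegerType)
val out = new StringWriter()
// "header" is an existing CSV option; headerFlag makes the first write emit the header row.
val gen = new UnivocityGenerator(schema, out, new CSVOptions(Map("header" -> "true")))

gen.write(InternalRow(UTF8String.fromString("alice"), 29))  // also writes "name,age" first
gen.write(InternalRow(UTF8String.fromString("bob"), 31))
gen.flush()
gen.close()
// out.toString now holds the header line followed by the two CSV rows.
```

The JSON side reads the same way, with `new JacksonGenerator(schema, out, new JSONOptions(...))` in place of the CSV generator.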

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #16496 from HyukjinKwon/SPARK-16101-write.
Authored by hyukjinkwon; committed by Wenchen Fan on 2017-01-21 10:43:52 +08:00.
Commit 54268b42dc (parent ea31f92bb8); 5 changed files with 135 additions and 115 deletions.

``` diff
@@ -20,14 +20,15 @@ package org.apache.spark.sql.execution.datasources.csv
 import java.nio.charset.{Charset, StandardCharsets}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
@@ -130,7 +131,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       CompressionCodecs.setCodecConfiguration(conf, codec)
     }
 
-    new CSVOutputWriterFactory(csvOptions)
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String,
+          dataSchema: StructType,
+          context: TaskAttemptContext): OutputWriter = {
+        new CsvOutputWriter(path, dataSchema, context, csvOptions)
+      }
+
+      override def getFileExtension(context: TaskAttemptContext): String = {
+        ".csv" + CodecStreams.getCompressionExtension(context)
+      }
+    }
   }
 
   override def buildReader(
@@ -228,3 +240,18 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     schema.foreach(field => verifyType(field.dataType))
   }
 }
+
+private[csv] class CsvOutputWriter(
+    path: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext,
+    params: CSVOptions) extends OutputWriter with Logging {
+
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+  private val gen = new UnivocityGenerator(dataSchema, writer, params)
+
+  override def write(row: InternalRow): Unit = gen.write(row)
+
+  override def close(): Unit = gen.close()
+}
```

``` diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv
 import java.nio.charset.StandardCharsets
 import java.util.Locale
 
+import com.univocity.parsers.csv.CsvWriterSettings
 import org.apache.commons.lang3.time.FastDateFormat
 
 import org.apache.spark.internal.Logging
@@ -126,6 +127,21 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
   val inputBufferSize = 128
 
   val isCommentSet = this.comment != '\u0000'
+
+  def asWriterSettings: CsvWriterSettings = {
+    val writerSettings = new CsvWriterSettings()
+    val format = writerSettings.getFormat
+    format.setDelimiter(delimiter)
+    format.setQuote(quote)
+    format.setQuoteEscape(escape)
+    format.setComment(comment)
+    writerSettings.setNullValue(nullValue)
+    writerSettings.setEmptyValue(nullValue)
+    writerSettings.setSkipEmptyLines(true)
+    writerSettings.setQuoteAllFields(quoteAll)
+    writerSettings.setQuoteEscapingEnabled(escapeQuotes)
+    writerSettings
+  }
 }
 
 object CSVOptions {
```
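
As a quick, hedged illustration of the mapping above (hypothetical option values, and again assuming same-package access since `CSVOptions` is `private[csv]`), the settings object returned by `asWriterSettings` simply mirrors the corresponding CSV options:

``` scala
val options = new CSVOptions(Map("delimiter" -> ";", "quoteAll" -> "true"))
val settings = options.asWriterSettings

assert(settings.getFormat.getDelimiter == ';')  // from format.setDelimiter(delimiter)
assert(settings.getQuoteAllFields())            // from writerSettings.setQuoteAllFields(quoteAll)
```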


``` diff
@@ -58,43 +58,3 @@ private[csv] class CsvReader(params: CSVOptions) {
    */
   def parseLine(line: String): Array[String] = parser.parseLine(line)
 }
-
-/**
- * Converts a sequence of string to CSV string
- *
- * @param params Parameters object for configuration
- * @param headers headers for columns
- */
-private[csv] class LineCsvWriter(
-    params: CSVOptions,
-    headers: Seq[String],
-    output: OutputStream) extends Logging {
-  private val writerSettings = new CsvWriterSettings
-  private val format = writerSettings.getFormat
-
-  format.setDelimiter(params.delimiter)
-  format.setQuote(params.quote)
-  format.setQuoteEscape(params.escape)
-  format.setComment(params.comment)
-
-  writerSettings.setNullValue(params.nullValue)
-  writerSettings.setEmptyValue(params.nullValue)
-  writerSettings.setSkipEmptyLines(true)
-  writerSettings.setQuoteAllFields(params.quoteAll)
-  writerSettings.setHeaders(headers: _*)
-  writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
-
-  private val writer = new CsvWriter(output, StandardCharsets.UTF_8, writerSettings)
-
-  def writeRow(row: Seq[String], includeHeader: Boolean): Unit = {
-    if (includeHeader) {
-      writer.writeHeaders()
-    }
-    writer.writeRow(row: _*)
-  }
-
-  def close(): Unit = {
-    writer.close()
-  }
-}
```


``` diff
@@ -159,75 +159,3 @@ object CSVRelation extends Logging {
     }
   }
 }
-
-private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
-  override def newInstance(
-      path: String,
-      dataSchema: StructType,
-      context: TaskAttemptContext): OutputWriter = {
-    new CsvOutputWriter(path, dataSchema, context, params)
-  }
-
-  override def getFileExtension(context: TaskAttemptContext): String = {
-    ".csv" + CodecStreams.getCompressionExtension(context)
-  }
-}
-
-private[csv] class CsvOutputWriter(
-    path: String,
-    dataSchema: StructType,
-    context: TaskAttemptContext,
-    params: CSVOptions) extends OutputWriter with Logging {
-
-  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
-  // When the value is null, this converter should not be called.
-  private type ValueConverter = (InternalRow, Int) => String
-
-  // `ValueConverter`s for all values in the fields of the schema
-  private val valueConverters: Array[ValueConverter] =
-    dataSchema.map(_.dataType).map(makeConverter).toArray
-
-  private var printHeader: Boolean = params.headerFlag
-  private val writer = CodecStreams.createOutputStream(context, new Path(path))
-  private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq, writer)
-
-  private def rowToString(row: InternalRow): Seq[String] = {
-    var i = 0
-    val values = new Array[String](row.numFields)
-    while (i < row.numFields) {
-      if (!row.isNullAt(i)) {
-        values(i) = valueConverters(i).apply(row, i)
-      } else {
-        values(i) = params.nullValue
-      }
-      i += 1
-    }
-    values
-  }
-
-  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
-    case DateType =>
-      (row: InternalRow, ordinal: Int) =>
-        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
-
-    case TimestampType =>
-      (row: InternalRow, ordinal: Int) =>
-        params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
-
-    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
-
-    case dt: DataType =>
-      (row: InternalRow, ordinal: Int) =>
-        row.get(ordinal, dt).toString
-  }
-
-  override def write(row: InternalRow): Unit = {
-    csvWriter.writeRow(rowToString(row), printHeader)
-    printHeader = false
-  }
-
-  override def close(): Unit = {
-    csvWriter.close()
-    writer.close()
-  }
-}
```


``` diff
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.Writer
+
+import com.univocity.parsers.csv.CsvWriter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+
+private[csv] class UnivocityGenerator(
+    schema: StructType,
+    writer: Writer,
+    options: CSVOptions = new CSVOptions(Map.empty[String, String])) {
+  private val writerSettings = options.asWriterSettings
+  writerSettings.setHeaders(schema.fieldNames: _*)
+  private val gen = new CsvWriter(writer, writerSettings)
+  private var printHeader = options.headerFlag
+
+  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
+  // When the value is null, this converter should not be called.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all values in the fields of the schema
+  private val valueConverters: Array[ValueConverter] =
+    schema.map(_.dataType).map(makeConverter).toArray
+
+  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case DateType =>
+      (row: InternalRow, ordinal: Int) =>
+        options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType =>
+      (row: InternalRow, ordinal: Int) =>
+        options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
+  }
+
+  private def convertRow(row: InternalRow): Seq[String] = {
+    var i = 0
+    val values = new Array[String](row.numFields)
+    while (i < row.numFields) {
+      if (!row.isNullAt(i)) {
+        values(i) = valueConverters(i).apply(row, i)
+      } else {
+        values(i) = options.nullValue
+      }
+      i += 1
+    }
+    values
+  }
+
+  /**
+   * Writes a single InternalRow to CSV using Univocity.
+   */
+  def write(row: InternalRow): Unit = {
+    if (printHeader) {
+      gen.writeHeaders()
+    }
+    gen.writeRow(convertRow(row): _*)
+    printHeader = false
+  }
+
+  def close(): Unit = gen.close()
+
+  def flush(): Unit = gen.flush()
+}
```
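
A short sketch of how the converters above behave at write time (made-up values; not part of this PR, and same-package access is assumed as before). Inside an `InternalRow`, `DateType` is stored as days since the epoch (`Int`) and `TimestampType` as microseconds since the epoch (`Long`), and the generator formats both through `options.dateFormat` and `options.timestampFormat`:

``` scala
import java.io.StringWriter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

val schema = new StructType().add("d", DateType).add("t", TimestampType)
val out = new StringWriter()
val gen = new UnivocityGenerator(schema, out)  // CSVOptions defaulted, as in the signature above

// Epoch date and timestamp; with the default formats these render as
// "1970-01-01" and an ISO-8601 timestamp in the JVM's default time zone.
gen.write(InternalRow(0, 0L))
gen.close()
```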