[SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource

## What changes were proposed in this pull request?

This PR proposes to use the text datasource for JSON schema inference.

This basically takes the same approach as https://github.com/apache/spark/pull/15813. Using a `Dataset` for the initial loading when inferring the schema has several advantages; please refer to SPARK-18362 for details.

It seems the JSON side was supposed to be fixed together but was split out, according to https://github.com/apache/spark/pull/15813:

> A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR.

Also, not going through `FileScanRDD` appears to affect some functionality; this problem is described in SPARK-19885 (for the CSV case).
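
For illustration only, here is a minimal sketch of the idea at the user-facing level (this is not the internal implementation, and the path is made up): load the input through the text datasource as a `Dataset[String]`, then let JSON schema inference run over that dataset rather than over a hand-built Hadoop RDD.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.StructType

object JsonSchemaInferenceSketch {
  def main(args: Array[String]): Unit = {
    // Assumes a local standalone session; in a real job `spark` already exists.
    val spark = SparkSession.builder()
      .appName("json-schema-inference-sketch")
      .master("local[*]")
      .getOrCreate()

    // The text datasource (TextFileFormat) reads the files via FileScanRDD,
    // which is what this patch switches the JSON inference path to.
    val lines: Dataset[String] = spark.read.textFile("/tmp/people.json")

    // Internally the patch samples the dataset and calls JsonInferSchema over
    // UTF8String rows; here we simply let the JSON reader infer the schema
    // from the same Dataset[String] to show the resulting StructType.
    val inferred: StructType = spark.read.json(lines).schema
    println(inferred.treeString)

    spark.stop()
  }
}
```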

## How was this patch tested?

Existing tests should cover this. It was also tested manually via `spark.read.json(path)` and by checking the UI.
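
For reference, the manual check amounts to something like the snippet below, run in `spark-shell` (so `spark` is already defined); the path is illustrative. The expectation is that the schema-inference scan now shows up as a `FileScanRDD`-backed read in the UI.

```scala
// Read a JSON path without a user-supplied schema so that schema inference runs,
// then check the SQL/Jobs tabs in the Spark UI for the text datasource scan.
val df = spark.read.json("/tmp/sample-json/")
df.printSchema()
df.show(5)
```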

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17255 from HyukjinKwon/json-filescanrdd.
Authored by hyukjinkwon on 2017-03-15 10:19:19 +08:00, committed by Wenchen Fan
parent dacc382f0c
commit 8fb2a02e2c
5 changed files with 123 additions and 95 deletions


@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
extraOptions.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val createParser = CreateJacksonParser.string _
val schema = userSpecifiedSchema.getOrElse {
JsonInferSchema.infer(
jsonDataset.rdd,
parsedOptions,
createParser)
TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
}
verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
val createParser = CreateJacksonParser.string _
val parsed = jsonDataset.rdd.mapPartitions { iter =>
val parser = new JacksonParser(schema, parsedOptions)
iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))


@@ -17,32 +17,30 @@
package org.apache.spark.sql.execution.datasources.json
import scala.reflect.ClassTag
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.spark.TaskContext
import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
* Common functions for parsing JSON files
* @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
*/
abstract class JsonDataSource[T] extends Serializable {
abstract class JsonDataSource extends Serializable {
def isSplitable: Boolean
/**
@@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable {
file: PartitionedFile,
parser: JacksonParser): Iterator[InternalRow]
/**
* Create an [[RDD]] that handles the preliminary parsing of [[T]] records
*/
protected def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[T]
/**
* A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
* for an instance of [[T]]
*/
def createParser(jsonFactory: JsonFactory, value: T): JsonParser
final def infer(
final def inferSchema(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): Option[StructType] = {
if (inputPaths.nonEmpty) {
val jsonSchema = JsonInferSchema.infer(
createBaseRdd(sparkSession, inputPaths),
parsedOptions,
createParser)
val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
checkConstraints(jsonSchema)
Some(jsonSchema)
} else {
@@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable {
}
}
protected def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType
/** Constraints to be imposed on schema to be stored. */
private def checkConstraints(schema: StructType): Unit = {
if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable {
}
object JsonDataSource {
def apply(options: JSONOptions): JsonDataSource[_] = {
def apply(options: JSONOptions): JsonDataSource = {
if (options.wholeFile) {
WholeFileJsonDataSource
} else {
TextInputJsonDataSource
}
}
/**
* Create a new [[RDD]] via the supplied callback if there is at least one file to process,
* otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
*/
def createBaseRdd[T : ClassTag](
sparkSession: SparkSession,
inputPaths: Seq[FileStatus])(
fn: (Configuration, String) => RDD[T]): RDD[T] = {
val paths = inputPaths.map(_.getPath)
if (paths.nonEmpty) {
val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
FileInputFormat.setInputPaths(job, paths: _*)
fn(job.getConfiguration, paths.mkString(","))
} else {
sparkSession.sparkContext.emptyRDD[T]
}
}
}
object TextInputJsonDataSource extends JsonDataSource[Text] {
object TextInputJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
// splittable if the underlying source is
true
}
override protected def createBaseRdd(
override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[Text] = {
JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
case (conf, name) =>
sparkSession.sparkContext.newAPIHadoopRDD(
conf,
classOf[TextInputFormat],
classOf[LongWritable],
classOf[Text])
.setName(s"JsonLines: $name")
.values // get the text column
}
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
inferFromDataset(json, parsedOptions)
}
def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
}
private def createBaseDataset(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): Dataset[String] = {
val paths = inputPaths.map(_.getPath.toString)
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
paths = paths,
className = classOf[TextFileFormat].getName
).resolveRelation(checkFilesExist = false))
.select("value").as(Encoders.STRING)
}
override def readFile(
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
parser: JacksonParser): Iterator[InternalRow] = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
}
private def textToUTF8String(value: Text): UTF8String = {
UTF8String.fromBytes(value.getBytes, 0, value.getLength)
}
override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
CreateJacksonParser.text(jsonFactory, value)
}
}
object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
object WholeFileJsonDataSource extends JsonDataSource {
override val isSplitable: Boolean = {
false
}
override protected def createBaseRdd(
override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
case (conf, name) =>
new BinaryFileRDD(
sparkSession.sparkContext,
classOf[StreamInputFormat],
classOf[String],
classOf[PortableDataStream],
conf,
sparkSession.sparkContext.defaultMinPartitions)
.setName(s"JsonFile: $name")
.values
}
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): StructType = {
val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
JsonInferSchema.infer(sampled, parsedOptions, createParser)
}
override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
private def createBaseRdd(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
val paths = inputPaths.map(_.getPath)
val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
val conf = job.getConfiguration
val name = paths.mkString(",")
FileInputFormat.setInputPaths(job, paths: _*)
new BinaryFileRDD(
sparkSession.sparkContext,
classOf[StreamInputFormat],
classOf[String],
classOf[PortableDataStream],
conf,
sparkSession.sparkContext.defaultMinPartitions)
.setName(s"JsonFile: $name")
.values
}
private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
CreateJacksonParser.inputStream(
jsonFactory,
CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))


@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
options,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
JsonDataSource(parsedOptions).infer(
JsonDataSource(parsedOptions).inferSchema(
sparkSession, files, parsedOptions)
}


@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
json: RDD[T],
configOptions: JSONOptions,
createParser: (JsonFactory, T) => JsonParser): StructType = {
require(configOptions.samplingRatio > 0,
s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
val shouldHandleCorruptRecord = configOptions.permissive
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
val schemaData = if (configOptions.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, configOptions.samplingRatio, 1)
}
// perform schema inference on each row and merge afterwards
val rootType = schemaData.mapPartitions { iter =>
val rootType = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>


@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.json
import org.apache.spark.input.PortableDataStream
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.json.JSONOptions
object JsonUtils {
/**
* Sample JSON dataset as configured by `samplingRatio`.
*/
def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
require(options.samplingRatio > 0,
s"samplingRatio (${options.samplingRatio}) should be greater than 0")
if (options.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, options.samplingRatio, 1)
}
}
/**
* Sample JSON RDD as configured by `samplingRatio`.
*/
def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
require(options.samplingRatio > 0,
s"samplingRatio (${options.samplingRatio}) should be greater than 0")
if (options.samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, options.samplingRatio, 1)
}
}
}
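
As a usage note rather than part of the diff: the `samplingRatio` consulted by `JsonUtils.sample` comes from the JSON reader options, so a caller can limit how much data schema inference touches. A hedged example (the path is illustrative):

```scala
// Infer the schema from roughly 10% of the input rows instead of all of them.
val df = spark
  .read
  .option("samplingRatio", "0.1")
  .json("/tmp/sample-json/")
```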