[SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource
## What changes were proposed in this pull request?

This PR proposes to use the text datasource for JSON schema inference. It follows the approach taken for CSV in https://github.com/apache/spark/pull/15813: using a Dataset for the initial load when inferring the schema brings several advantages; please refer to SPARK-18362 for details.

It seems the JSON one was supposed to be fixed together with CSV but was taken out, according to https://github.com/apache/spark/pull/15813:

> A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR.

Also, the current code path affects some functionality because it does not use `FileScanRDD`. This problem is described in SPARK-19885 (there for CSV's case).

## How was this patch tested?

Existing tests should cover this. Also tested manually via `spark.read.json(path)` and checking the UI.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #17255 from HyukjinKwon/json-filescanrdd.
Parent: dacc382f0c
Commit: 8fb2a02e2c
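For context, the manual check mentioned under testing can be reproduced with a short snippet. This is a sketch, assuming a hypothetical directory of JSON-lines files at `/tmp/people.json`; with this change, the SQL tab of the web UI shows the inference pass running as a regular file scan (`FileScanRDD`):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("json-inference-check")
  .getOrCreate()

// Hypothetical input: a directory (or file) of JSON-lines records.
val path = "/tmp/people.json"

// Schema inference now loads the input through the text datasource,
// i.e. an ordinary file scan rather than a hand-rolled Hadoop RDD.
val df = spark.read.json(path)
df.printSchema()

// Inspect the physical plan; the file scans also show up in the web UI.
df.explain()
```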
In `DataFrameReader.scala`, the imports switch from `JsonInferSchema` to `TextInputJsonDataSource`:

```diff
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
+import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String

```
In `DataFrameReader.json(jsonDataset: Dataset[String])`, schema inference is delegated to `TextInputJsonDataSource.inferFromDataset`, and the parser factory is only created for the actual parsing pass:

```diff
@@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val createParser = CreateJacksonParser.string _

     val schema = userSpecifiedSchema.getOrElse {
-      JsonInferSchema.infer(
-        jsonDataset.rdd,
-        parsedOptions,
-        createParser)
+      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }

     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)

+    val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
```
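For reference, the `Dataset[String]` overload changed above can be exercised as follows; a minimal sketch with inline data, so no files are involved:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// JSON records held directly in a Dataset[String]; inference now goes
// through TextInputJsonDataSource.inferFromDataset rather than a raw
// RDD pass with its own Jackson parser factory.
val jsonDataset = Seq(
  """{"name": "a", "age": 1}""",
  """{"name": "b", "age": 2}""").toDS()

val df = spark.read.json(jsonDataset)
df.show()
```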
In `JsonDataSource.scala`, the type parameter is dropped and the text datasource imports come in:

```diff
@@ -17,32 +17,30 @@
 package org.apache.spark.sql.execution.datasources.json

-import scala.reflect.ClassTag
-
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat

 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

 /**
  * Common functions for parsing JSON files
- * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
  */
-abstract class JsonDataSource[T] extends Serializable {
+abstract class JsonDataSource extends Serializable {
   def isSplitable: Boolean

   /**
```
The per-type `createBaseRdd`/`createParser` contract is replaced by a single `infer` hook per implementation, and the public entry point is renamed to `inferSchema`:

```diff
@@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable {
       file: PartitionedFile,
       parser: JacksonParser): Iterator[InternalRow]

-  /**
-   * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
-   */
-  protected def createBaseRdd(
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[T]
-
-  /**
-   * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
-   * for an instance of [[T]]
-   */
-  def createParser(jsonFactory: JsonFactory, value: T): JsonParser
-
-  final def infer(
+  final def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      val jsonSchema = JsonInferSchema.infer(
-        createBaseRdd(sparkSession, inputPaths),
-        parsedOptions,
-        createParser)
+      val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
       checkConstraints(jsonSchema)
       Some(jsonSchema)
     } else {
```
```diff
@@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable {
     }
   }

+  protected def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType
+
   /** Constraints to be imposed on schema to be stored. */
   private def checkConstraints(schema: StructType): Unit = {
     if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
```
`TextInputJsonDataSource` now builds a `Dataset[String]` through `TextFileFormat` instead of a `newAPIHadoopRDD`, and the shared `createBaseRdd` helper goes away:

```diff
@@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable {
 }

 object JsonDataSource {
-  def apply(options: JSONOptions): JsonDataSource[_] = {
+  def apply(options: JSONOptions): JsonDataSource = {
     if (options.wholeFile) {
       WholeFileJsonDataSource
     } else {
       TextInputJsonDataSource
     }
   }
-
-  /**
-   * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
-   * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
-   */
-  def createBaseRdd[T : ClassTag](
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus])(
-      fn: (Configuration, String) => RDD[T]): RDD[T] = {
-    val paths = inputPaths.map(_.getPath)
-
-    if (paths.nonEmpty) {
-      val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
-      FileInputFormat.setInputPaths(job, paths: _*)
-      fn(job.getConfiguration, paths.mkString(","))
-    } else {
-      sparkSession.sparkContext.emptyRDD[T]
-    }
-  }
 }

-object TextInputJsonDataSource extends JsonDataSource[Text] {
+object TextInputJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     // splittable if the underlying source is
     true
   }

-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[Text] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        sparkSession.sparkContext.newAPIHadoopRDD(
-          conf,
-          classOf[TextInputFormat],
-          classOf[LongWritable],
-          classOf[Text])
-          .setName(s"JsonLines: $name")
-          .values // get the text column
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+    inferFromDataset(json, parsedOptions)
+  }
+
+  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+    val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
+    val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
+    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+  }
+
+  private def createBaseDataset(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value").as(Encoders.STRING)
+  }

   override def readFile(
```
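The `createBaseDataset` above is, roughly, the internal equivalent of reading the files with the public text reader; a user-level sketch of the same idea, assuming a hypothetical `path`:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Approximately what createBaseDataset does: load raw lines through the
// text datasource (hence FileScanRDD), one JSON record per line.
val path = "/tmp/people.json"  // hypothetical input
val lines: Dataset[String] = spark.read.text(path).select("value").as[String]

// Inference then samples the lines and parses them with Jackson.
val schema = spark.read.json(lines).schema
```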
`WholeFileJsonDataSource` keeps its `BinaryFileRDD`, but inlines the RDD creation and makes its parser factory private:

```diff
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
       parser: JacksonParser): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
   }

   private def textToUTF8String(value: Text): UTF8String = {
     UTF8String.fromBytes(value.getBytes, 0, value.getLength)
   }
-
-  override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
-    CreateJacksonParser.text(jsonFactory, value)
-  }
 }

-object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+object WholeFileJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     false
   }

-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        new BinaryFileRDD(
-          sparkSession.sparkContext,
-          classOf[StreamInputFormat],
-          classOf[String],
-          classOf[PortableDataStream],
-          conf,
-          sparkSession.sparkContext.defaultMinPartitions)
-          .setName(s"JsonFile: $name")
-          .values
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+    val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
+    JsonInferSchema.infer(sampled, parsedOptions, createParser)
   }

-  override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+  private def createBaseRdd(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+    val paths = inputPaths.map(_.getPath)
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    val conf = job.getConfiguration
+    val name = paths.mkString(",")
+    FileInputFormat.setInputPaths(job, paths: _*)
+    new BinaryFileRDD(
+      sparkSession.sparkContext,
+      classOf[StreamInputFormat],
+      classOf[String],
+      classOf[PortableDataStream],
+      conf,
+      sparkSession.sparkContext.defaultMinPartitions)
+      .setName(s"JsonFile: $name")
+      .values
+  }
+
+  private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
     CreateJacksonParser.inputStream(
       jsonFactory,
       CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
```
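`WholeFileJsonDataSource` above is selected via the `wholeFile` option and keeps each file as a single record, hence `isSplitable = false`; a short usage sketch with a hypothetical directory:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Each input file is parsed as one JSON document (e.g. a pretty-printed
// object spanning multiple lines), so the input cannot be split.
val multiLineDf = spark.read
  .option("wholeFile", true)
  .json("/tmp/multiline-json")  // hypothetical directory of whole-file JSON
```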
In `JsonFileFormat.scala`, the renamed entry point is picked up:

```diff
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    JsonDataSource(parsedOptions).infer(
+    JsonDataSource(parsedOptions).inferSchema(
       sparkSession, files, parsedOptions)
   }

```
In `JsonInferSchema.scala`, sampling moves out of `infer`; it now happens in the callers via `JsonUtils.sample`:

```diff
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
       json: RDD[T],
       configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
-    require(configOptions.samplingRatio > 0,
-      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
     val shouldHandleCorruptRecord = configOptions.permissive
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
-    val schemaData = if (configOptions.samplingRatio > 0.99) {
-      json
-    } else {
-      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
-    }

     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val rootType = json.mapPartitions { iter =>
       val factory = new JsonFactory()
       configOptions.setJacksonOptions(factory)
       iter.flatMap { row =>
```
A new file, `JsonUtils.scala`, hosts the shared sampling logic:

```diff
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
+
+object JsonUtils {
+  /**
+   * Sample JSON dataset as configured by `samplingRatio`.
+   */
+  def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+
+  /**
+   * Sample JSON RDD as configured by `samplingRatio`.
+   */
+  def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+}
```
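These helpers are driven by the `samplingRatio` read option; a minimal usage sketch (hypothetical path), where a ratio at or below 0.99 actually triggers the `sample` call:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// Infer the schema from roughly 10% of the records rather than all of them.
// samplingRatio must be > 0; values above 0.99 skip sampling entirely.
val sampledDf = spark.read
  .option("samplingRatio", 0.1)
  .json("/tmp/large-json")  // hypothetical directory
```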