diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index f1bce1aa41..309654c804 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
+import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val createParser = CreateJacksonParser.string _
 
     val schema = userSpecifiedSchema.getOrElse {
-      JsonInferSchema.infer(
-        jsonDataset.rdd,
-        parsedOptions,
-        createParser)
+      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 18843bfc30..84f026620d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,32 +17,30 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
-import scala.reflect.ClassTag
-
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
  * Common functions for parsing JSON files
- * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
  */
-abstract class JsonDataSource[T] extends Serializable {
+abstract class JsonDataSource extends Serializable {
   def isSplitable: Boolean
 
   /**
@@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable {
       file: PartitionedFile,
       parser: JacksonParser): Iterator[InternalRow]
 
-  /**
-   * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
-   */
-  protected def createBaseRdd(
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[T]
-
-  /**
-   * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
-   * for an instance of [[T]]
-   */
-  def createParser(jsonFactory: JsonFactory, value: T): JsonParser
-
-  final def infer(
+  final def inferSchema(
      sparkSession: SparkSession,
      inputPaths: Seq[FileStatus],
      parsedOptions: JSONOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      val jsonSchema = JsonInferSchema.infer(
-        createBaseRdd(sparkSession, inputPaths),
-        parsedOptions,
-        createParser)
+      val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
       checkConstraints(jsonSchema)
       Some(jsonSchema)
     } else {
@@ -82,6 +64,11 @@
     }
   }
 
+  protected def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType
+
   /** Constraints to be imposed on schema to be stored. */
   private def checkConstraints(schema: StructType): Unit = {
     if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -95,53 +82,46 @@
 }
 
 object JsonDataSource {
-  def apply(options: JSONOptions): JsonDataSource[_] = {
+  def apply(options: JSONOptions): JsonDataSource = {
     if (options.wholeFile) {
       WholeFileJsonDataSource
     } else {
       TextInputJsonDataSource
     }
   }
-
-  /**
-   * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
-   * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
-   */
-  def createBaseRdd[T : ClassTag](
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus])(
-      fn: (Configuration, String) => RDD[T]): RDD[T] = {
-    val paths = inputPaths.map(_.getPath)
-
-    if (paths.nonEmpty) {
-      val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
-      FileInputFormat.setInputPaths(job, paths: _*)
-      fn(job.getConfiguration, paths.mkString(","))
-    } else {
-      sparkSession.sparkContext.emptyRDD[T]
-    }
-  }
 }
 
-object TextInputJsonDataSource extends JsonDataSource[Text] {
+object TextInputJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     // splittable if the underlying source is
     true
   }
 
-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[Text] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        sparkSession.sparkContext.newAPIHadoopRDD(
-          conf,
-          classOf[TextInputFormat],
-          classOf[LongWritable],
-          classOf[Text])
-          .setName(s"JsonLines: $name")
-          .values // get the text column
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+    inferFromDataset(json, parsedOptions)
+  }
+
+  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+    val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
+    val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
+    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+  }
+
+  private def createBaseDataset(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value").as(Encoders.STRING)
   }
 
   override def readFile(
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
       parser: JacksonParser): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
   }
 
   private def textToUTF8String(value: Text): UTF8String = {
     UTF8String.fromBytes(value.getBytes, 0, value.getLength)
   }
-
-  override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
-    CreateJacksonParser.text(jsonFactory, value)
-  }
 }
 
-object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+object WholeFileJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     false
   }
 
-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        new BinaryFileRDD(
-          sparkSession.sparkContext,
-          classOf[StreamInputFormat],
-          classOf[String],
-          classOf[PortableDataStream],
-          conf,
-          sparkSession.sparkContext.defaultMinPartitions)
-          .setName(s"JsonFile: $name")
-          .values
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+    val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
+    JsonInferSchema.infer(sampled, parsedOptions, createParser)
   }
 
-  override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+  private def createBaseRdd(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
+    val paths = inputPaths.map(_.getPath)
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    val conf = job.getConfiguration
+    val name = paths.mkString(",")
+    FileInputFormat.setInputPaths(job, paths: _*)
+    new BinaryFileRDD(
+      sparkSession.sparkContext,
+      classOf[StreamInputFormat],
+      classOf[String],
+      classOf[PortableDataStream],
+      conf,
+      sparkSession.sparkContext.defaultMinPartitions)
+      .setName(s"JsonFile: $name")
+      .values
+  }
+
+  private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
     CreateJacksonParser.inputStream(
       jsonFactory,
       CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 902fee5a7e..a9dd91eba6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    JsonDataSource(parsedOptions).infer(
+    JsonDataSource(parsedOptions).inferSchema(
       sparkSession, files, parsedOptions)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index ab09358115..7475f8ec79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
       json: RDD[T],
       configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
-    require(configOptions.samplingRatio > 0,
-      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
     val shouldHandleCorruptRecord = configOptions.permissive
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
-    val schemaData = if (configOptions.samplingRatio > 0.99) {
-      json
-    } else {
-      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
-    }
 
     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val rootType = json.mapPartitions { iter =>
       val factory = new JsonFactory()
       configOptions.setJacksonOptions(factory)
       iter.flatMap { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
new file mode 100644
index 0000000000..d511594c5d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
+
+object JsonUtils {
+  /**
+   * Sample JSON dataset as configured by `samplingRatio`.
+   */
+  def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+
+  /**
+   * Sample JSON RDD as configured by `samplingRatio`.
+   */
+  def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+}
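Reviewer note: below is a minimal usage sketch, not part of the patch, showing the user-facing path this change touches. The object name, sample records, and the `samplingRatio` value are made up for illustration; `spark.read.json(jsonDataset: Dataset[String])` and the `samplingRatio` option are existing public API. With this change, schema inference for a `Dataset[String]` goes through `TextInputJsonDataSource.inferFromDataset`, and sampling is applied by the new `JsonUtils.sample` before `JsonInferSchema.infer` runs.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

object JsonInferenceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("json-inference-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // A tiny JSON-lines dataset just for demonstration; with an input this
    // small you would normally leave samplingRatio at its default of 1.0.
    val jsonLines: Dataset[String] = Seq(
      """{"id": 1, "name": "a"}""",
      """{"id": 2, "name": "b", "tags": ["x", "y"]}"""
    ).toDS()

    // Schema inference routes through TextInputJsonDataSource.inferFromDataset;
    // a samplingRatio below 1.0 is applied by JsonUtils.sample before inference,
    // trading accuracy for speed on large inputs.
    val df = spark.read
      .option("samplingRatio", "0.5")
      .json(jsonLines)

    df.printSchema()
    spark.stop()
  }
}
```

The file-based path (`spark.read.json("path/to/*.json")`) behaves analogously: `JsonFileFormat.inferSchema` delegates to `JsonDataSource(parsedOptions).inferSchema`, which for line-delimited input builds a `Dataset[String]` over `TextFileFormat` and reuses the same `inferFromDataset` code shown above.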