[SPARK-5938] [SPARK-5443] [SQL] Improve JsonRDD performance
This patch comprises a few related pieces of work:

* Schema inference is performed directly on the JSON token stream
* `String => Row` conversion populates Spark SQL structures without intermediate types
* Projection pushdown is implemented via CatalystScan for DataFrame queries
* Support for the legacy parser by setting `spark.sql.json.useJacksonStreamingAPI` to `false`

Performance improvements depend on the schema and queries being executed, but it should be faster across the board. Below are benchmarks using the last.fm Million Song dataset:

```
Command                                             | Baseline | Patched
----------------------------------------------------|----------|--------
import sqlContext.implicits._                       |          |
val df = sqlContext.jsonFile("/tmp/lastfm.json")    | 70.0s    | 14.6s
df.count()                                          | 28.8s    | 6.2s
df.rdd.count()                                      | 35.3s    | 21.5s
df.where($"artist" === "Robert Hood").collect()     | 28.3s    | 16.9s
```

To prepare this dataset for benchmarking, follow these steps:

```
# Fetch the datasets from http://labrosa.ee.columbia.edu/millionsong/lastfm
wget http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_test.zip \
     http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_train.zip

# Decompress and combine, pipe through `jq -c` to ensure there is one record per line
unzip -p lastfm_test.zip lastfm_train.zip | jq -c . > lastfm.json
```

Author: Nathan Howell <nhowell@godaddy.com>

Closes #5801 from NathanHowell/json-performance and squashes the following commits:

26fea31 [Nathan Howell] Recreate the baseRDD each for each scan operation
a7ebeb2 [Nathan Howell] Increase coverage of inserts into a JSONRelation
e06a1dd [Nathan Howell] Add comments to the `useJacksonStreamingAPI` config flag
6822712 [Nathan Howell] Split up JsonRDD2 into multiple objects
fa8234f [Nathan Howell] Wrap long lines
b31917b [Nathan Howell] Rename `useJsonRDD2` to `useJacksonStreamingAPI`
15c5d1b [Nathan Howell] JSONRelation's baseRDD need not be lazy
f8add6e [Nathan Howell] Add comments on lack of support for precision and scale DecimalTypes
fa0be47 [Nathan Howell] Remove unused default case in the field parser
80dba17 [Nathan Howell] Add comments regarding null handling and empty strings
842846d [Nathan Howell] Point the empty schema inference test at JsonRDD2
ab6ee87 [Nathan Howell] Add projection pushdown support to JsonRDD/JsonRDD2
f636c14 [Nathan Howell] Enable JsonRDD2 by default, add a flag to switch back to JsonRDD
0bbc445 [Nathan Howell] Improve JSON parsing and type inference performance
7ca70c1 [Nathan Howell] Eliminate arrow pattern, replace with pattern matches
commit 2d6612cc8b (parent 9cfa9a516e)
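For anyone who wants to try the change locally, a minimal usage sketch is below. It assumes a Spark 1.4-era shell with the usual `sc`/`sqlContext`, a line-delimited JSON file at `/tmp/lastfm.json` prepared as described above, and it uses the `spark.sql.json.useJacksonStreamingAPI` flag introduced by this patch; the queries mirror the benchmark table.

```scala
// Sketch only: assumes a running SparkContext (`sc`) and the 1.4-era SQLContext API.
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// New Jackson streaming path (the default after this patch)
val df = sqlContext.jsonFile("/tmp/lastfm.json")
df.count()
df.where($"artist" === "Robert Hood").collect()

// Fall back to the legacy JsonRDD parser if needed
sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
val legacyDf = sqlContext.jsonFile("/tmp/lastfm.json")
```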
```
@@ -26,7 +26,14 @@ object HiveTypeCoercion {
  // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
  // The conversion for integral and floating point types have a linear widening hierarchy:
  private val numericPrecedence =
    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
    IndexedSeq(
      ByteType,
      ShortType,
      IntegerType,
      LongType,
      FloatType,
      DoubleType,
      DecimalType.Unlimited)

  /**
   * Find the tightest common type of two types that might be used in a binary expression.
@@ -34,25 +41,21 @@ object HiveTypeCoercion {
   * with primitive types, because in that case the precision and scale of the result depends on
   * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
   */
  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
    val valueTypes = Seq(t1, t2).filter(t => t != NullType)
    if (valueTypes.distinct.size > 1) {
      // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
      if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
        Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
      } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
        // Fixed-precision decimals can up-cast into unlimited
        if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
          Some(DecimalType.Unlimited)
        } else {
          None
        }
      } else {
        None
      }
    } else {
      Some(if (valueTypes.size == 0) NullType else valueTypes.head)
    }
  val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
    case (t1, t2) if t1 == t2 => Some(t1)
    case (NullType, t1) => Some(t1)
    case (t1, NullType) => Some(t1)

    // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
    case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
      val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
      Some(numericPrecedence(index))

    // Fixed-precision decimals can up-cast into unlimited
    case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
    case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)

    case _ => None
  }
}
```
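A rough illustration of what the widened `numericPrecedence` hierarchy means in practice: when the same JSON field carries an integral value in one record and a floating point value in another, `findTightestCommonType` (via `compatibleType` below) promotes the field to the wider type. The snippet is illustrative only; it assumes `sc`/`sqlContext` and the 1.4-era `jsonRDD` API.

```scala
// Illustrative sketch: a field seen as a long in one record and a double in
// another is promoted to DoubleType, per the numericPrecedence ordering.
val records = sc.parallelize(Seq(
  """{"price": 10}""",   // inferred as LongType on its own
  """{"price": 10.5}"""  // inferred as DoubleType on its own
))

sqlContext.jsonRDD(records).printSchema()
// expected:
// root
//  |-- price: double (nullable = true)
```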
```
@@ -134,6 +134,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
      throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
  }

  private[sql] def getFieldIndex(name: String): Option[Int] = {
    nameToIndex.get(name)
  }

  protected[sql] def toAttributes: Seq[AttributeReference] =
    map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
```
```
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils

@@ -1415,7 +1415,7 @@ class DataFrame private[sql](
      new Iterator[String] {
        override def hasNext: Boolean = iter.hasNext
        override def next(): String = {
          JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
          JacksonGenerator(rowSchema, gen)(iter.next())
          gen.flush()

          val json = writer.toString
```
```
@@ -73,6 +73,8 @@ private[spark] object SQLConf {

  val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

  val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"

  object Deprecated {
    val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
  }

@@ -166,6 +168,12 @@ private[sql] class SQLConf extends Serializable {

  private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean

  /**
   * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
   */
  private[spark] def useJacksonStreamingAPI: Boolean =
    getConf(USE_JACKSON_STREAMING_API, "true").toBoolean

  /**
   * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
   * a broadcast value during the physical executions of join operations. Setting this to -1
```
```
@@ -659,13 +659,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
   */
  @Experimental
  def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
    val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
    val appliedSchema =
      Option(schema).getOrElse(
        JsonRDD.nullTypeToStringType(
          JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
    if (conf.useJacksonStreamingAPI) {
      baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this))
    } else {
      val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
      val appliedSchema =
        Option(schema).getOrElse(
          JsonRDD.nullTypeToStringType(
            JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
      val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
      createDataFrame(rowRDD, appliedSchema, needsConversion = false)
    }
  }

  /**

@@ -689,12 +693,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
   */
  @Experimental
  def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
    val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
    val appliedSchema =
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
    val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
    if (conf.useJacksonStreamingAPI) {
      baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this))
    } else {
      val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
      val appliedSchema =
        JsonRDD.nullTypeToStringType(
          JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
      val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
      createDataFrame(rowRDD, appliedSchema, needsConversion = false)
    }
  }

  /**
```
```
@@ -0,0 +1,171 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.json

import com.fasterxml.jackson.core._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._

private[sql] object InferSchema {
  /**
   * Infer the type of a collection of json records in three stages:
   *   1. Infer the type of each record
   *   2. Merge types by choosing the lowest type necessary to cover equal keys
   *   3. Replace any remaining null fields with string, the top type
   */
  def apply(
      json: RDD[String],
      samplingRatio: Double = 1.0,
      columnNameOfCorruptRecords: String): StructType = {
    require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
    val schemaData = if (samplingRatio > 0.99) {
      json
    } else {
      json.sample(withReplacement = false, samplingRatio, 1)
    }

    // perform schema inference on each row and merge afterwards
    schemaData.mapPartitions { iter =>
      val factory = new JsonFactory()
      iter.map { row =>
        try {
          val parser = factory.createParser(row)
          parser.nextToken()
          inferField(parser)
        } catch {
          case _: JsonParseException =>
            StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
        }
      }
    }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match {
      case st: StructType => nullTypeToStringType(st)
    }
  }

  /**
   * Infer the type of a json document from the parser's token stream
   */
  private def inferField(parser: JsonParser): DataType = {
    import com.fasterxml.jackson.core.JsonToken._
    parser.getCurrentToken match {
      case null | VALUE_NULL => NullType

      case FIELD_NAME =>
        parser.nextToken()
        inferField(parser)

      case VALUE_STRING if parser.getTextLength < 1 =>
        // Zero length strings and nulls have special handling to deal
        // with JSON generators that do not distinguish between the two.
        // To accurately infer types for empty strings that are really
        // meant to represent nulls we assume that the two are isomorphic
        // but will defer treating null fields as strings until all the
        // record fields' types have been combined.
        NullType

      case VALUE_STRING => StringType
      case START_OBJECT =>
        val builder = Seq.newBuilder[StructField]
        while (nextUntil(parser, END_OBJECT)) {
          builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
        }

        StructType(builder.result().sortBy(_.name))

      case START_ARRAY =>
        // If this JSON array is empty, we use NullType as a placeholder.
        // If this array is not empty in other JSON objects, we can resolve
        // the type as we pass through all JSON objects.
        var elementType: DataType = NullType
        while (nextUntil(parser, END_ARRAY)) {
          elementType = compatibleType(elementType, inferField(parser))
        }

        ArrayType(elementType)

      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
        import JsonParser.NumberType._
        parser.getNumberType match {
          // For Integer values, use LongType by default.
          case INT | LONG => LongType
          // Since we do not have a data type backed by BigInteger,
          // when we see a Java BigInteger, we use DecimalType.
          case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
          case FLOAT | DOUBLE => DoubleType
        }

      case VALUE_TRUE | VALUE_FALSE => BooleanType
    }
  }

  private def nullTypeToStringType(struct: StructType): StructType = {
    val fields = struct.fields.map {
      case StructField(fieldName, dataType, nullable, _) =>
        val newType = dataType match {
          case NullType => StringType
          case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
          case ArrayType(struct: StructType, containsNull) =>
            ArrayType(nullTypeToStringType(struct), containsNull)
          case struct: StructType => nullTypeToStringType(struct)
          case other: DataType => other
        }

        StructField(fieldName, newType, nullable)
    }

    StructType(fields)
  }

  /**
   * Remove top-level ArrayType wrappers and merge the remaining schemas
   */
  private def compatibleRootType: (DataType, DataType) => DataType = {
    case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2)
    case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2)
    case (ty1, ty2) => compatibleType(ty1, ty2)
  }

  /**
   * Returns the most general data type for two given data types.
   */
  private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
    HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
      (t1, t2) match {
        case (other: DataType, NullType) => other
        case (NullType, other: DataType) => other
        case (StructType(fields1), StructType(fields2)) =>
          val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
            case (name, fieldTypes) =>
              val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
              StructField(name, dataType, nullable = true)
          }
          StructType(newFields.toSeq.sortBy(_.name))

        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)

        // strings and every string is a Json object.
        case (_, _) => StringType
      }
    }
  }
}
```
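A small, hypothetical example of the three-stage inference implemented above: per-record types are merged across records, and fields seen only as `null` or as empty strings fall back to `StringType` via `nullTypeToStringType`. It assumes `sc`/`sqlContext` as in the benchmark commands.

```scala
// Two records with overlapping keys: "a" is only ever null or empty,
// "b" appears in just one of the records.
val sample = sc.parallelize(Seq(
  """{"a": null, "b": 1}""",
  """{"a": ""}"""
))

sqlContext.jsonRDD(sample).printSchema()
// expected:
// root
//  |-- a: string (nullable = true)   // NullType replaced with the top type, string
//  |-- b: long (nullable = true)     // integral JSON numbers default to LongType
```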
```
@@ -22,14 +22,16 @@ import java.io.IOException
import org.apache.hadoop.fs.Path

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}


private[sql] class DefaultSource
  extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
  extends RelationProvider
  with SchemaRelationProvider
  with CreatableRelationProvider {

  private def checkPath(parameters: Map[String, String]): String = {
    parameters.getOrElse("path", sys.error("'path' must be specified for json data."))

@@ -42,7 +44,7 @@ private[sql] class DefaultSource
    val path = checkPath(parameters)
    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

    JSONRelation(path, samplingRatio, None)(sqlContext)
    new JSONRelation(path, samplingRatio, None, sqlContext)
  }

  /** Returns a new base relation with the given schema and parameters. */

@@ -53,7 +55,7 @@ private[sql] class DefaultSource
    val path = checkPath(parameters)
    val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)

    JSONRelation(path, samplingRatio, Some(schema))(sqlContext)
    new JSONRelation(path, samplingRatio, Some(schema), sqlContext)
  }

  override def createRelation(

@@ -101,32 +103,87 @@ private[sql] class DefaultSource
  }
}

private[sql] case class JSONRelation(
    path: String,
    samplingRatio: Double,
private[sql] class JSONRelation(
    // baseRDD is not immutable with respect to INSERT OVERWRITE
    // and so it must be recreated at least as often as the
    // underlying inputs are modified. To be safe, a function is
    // used instead of a regular RDD value to ensure a fresh RDD is
    // recreated for each and every operation.
    baseRDD: () => RDD[String],
    val path: Option[String],
    val samplingRatio: Double,
    userSpecifiedSchema: Option[StructType])(
    @transient val sqlContext: SQLContext)
  extends BaseRelation
  with TableScan
  with InsertableRelation {
  with InsertableRelation
  with CatalystScan {

  // TODO: Support partitioned JSON relation.
  private def baseRDD = sqlContext.sparkContext.textFile(path)
  def this(
      path: String,
      samplingRatio: Double,
      userSpecifiedSchema: Option[StructType],
      sqlContext: SQLContext) =
    this(
      () => sqlContext.sparkContext.textFile(path),
      Some(path),
      samplingRatio,
      userSpecifiedSchema)(sqlContext)

  private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI

  override val needConversion: Boolean = false

  override val schema = userSpecifiedSchema.getOrElse(
    JsonRDD.nullTypeToStringType(
      JsonRDD.inferSchema(
        baseRDD,
  override lazy val schema = userSpecifiedSchema.getOrElse {
    if (useJacksonStreamingAPI) {
      InferSchema(
        baseRDD(),
        samplingRatio,
        sqlContext.conf.columnNameOfCorruptRecord)))
        sqlContext.conf.columnNameOfCorruptRecord)
    } else {
      JsonRDD.nullTypeToStringType(
        JsonRDD.inferSchema(
          baseRDD(),
          samplingRatio,
          sqlContext.conf.columnNameOfCorruptRecord))
    }
  }

  override def buildScan(): RDD[Row] =
    JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
  override def buildScan(): RDD[Row] = {
    if (useJacksonStreamingAPI) {
      JacksonParser(
        baseRDD(),
        schema,
        sqlContext.conf.columnNameOfCorruptRecord)
    } else {
      JsonRDD.jsonStringToRow(
        baseRDD(),
        schema,
        sqlContext.conf.columnNameOfCorruptRecord)
    }
  }

  override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = {
    if (useJacksonStreamingAPI) {
      JacksonParser(
        baseRDD(),
        StructType.fromAttributes(requiredColumns),
        sqlContext.conf.columnNameOfCorruptRecord)
    } else {
      JsonRDD.jsonStringToRow(
        baseRDD(),
        StructType.fromAttributes(requiredColumns),
        sqlContext.conf.columnNameOfCorruptRecord)
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val filesystemPath = new Path(path)
    val filesystemPath = path match {
      case Some(p) => new Path(p)
      case None =>
        throw new IOException(s"Cannot INSERT into table with no path defined")
    }

    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

    if (overwrite) {

@@ -147,7 +204,7 @@ private[sql] case class JSONRelation(
      }
    }
    // Write the data.
    data.toJSON.saveAsTextFile(path)
    data.toJSON.saveAsTextFile(filesystemPath.toString)
    // Right now, we assume that the schema is not changed. We will not update the schema.
    // schema = data.schema
  } else {
```
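Because `JSONRelation` now also mixes in `CatalystScan`, a DataFrame query that touches only a subset of columns hands just those attributes to `buildScan(requiredColumns, filters)`, and the streaming parser can skip the remaining fields of each record. A hedged sketch of how that looks from the user side (the file path and column name are placeholders from the benchmark above):

```scala
// Only the "artist" column is required by this query, so buildScan receives a
// single-attribute schema and the other fields of each JSON record are skipped.
val df = sqlContext.jsonFile("/tmp/lastfm.json")
df.select($"artist").where($"artist" === "Robert Hood").count()

// df.explain(true) can be used to inspect the pruned scan in the physical plan.
```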
```
@@ -0,0 +1,77 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.json

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

private[sql] object JacksonGenerator {
  /** Transforms a single Row to JSON using Jackson
   *
   * @param rowSchema the schema object used for conversion
   * @param gen a JsonGenerator object
   * @param row The row to convert
   */
  def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = {
    def valWriter: (DataType, Any) => Unit = {
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v: String) => gen.writeString(v)
      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v) => gen.writeString(v.toString)
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))

      case (ArrayType(ty, _), v: Seq[_] ) =>
        gen.writeStartArray()
        v.foreach(valWriter(ty,_))
        gen.writeEndArray()

      case (MapType(kv,vv, _), v: Map[_,_]) =>
        gen.writeStartObject()
        v.foreach { p =>
          gen.writeFieldName(p._1.toString)
          valWriter(vv,p._2)
        }
        gen.writeEndObject()

      case (StructType(ty), v: Row) =>
        gen.writeStartObject()
        ty.zip(v.toSeq).foreach {
          case (_, null) =>
          case (field, v) =>
            gen.writeFieldName(field.name)
            valWriter(field.dataType, v)
        }
        gen.writeEndObject()
    }

    valWriter(rowSchema, row)
  }
}
```
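`DataFrame.toJSON` (updated in the `DataFrame` hunk above) now delegates to this `JacksonGenerator` object. A quick round-trip example, assuming the same `sc`/`sqlContext` setup as before:

```scala
// Round trip: parse JSON into a DataFrame, then serialize the rows back to
// JSON strings through JacksonGenerator via DataFrame.toJSON.
val input = sc.parallelize(Seq("""{"name": "alice", "scores": [1, 2, 3]}"""))
val df = sqlContext.jsonRDD(input)

df.toJSON.collect().foreach(println)
// expected: {"name":"alice","scores":[1,2,3]}
```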
```
@@ -0,0 +1,215 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.json

import java.io.ByteArrayOutputStream
import java.sql.Timestamp

import scala.collection.Map

import com.fasterxml.jackson.core._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._

private[sql] object JacksonParser {
  def apply(
      json: RDD[String],
      schema: StructType,
      columnNameOfCorruptRecords: String): RDD[Row] = {
    parseJson(json, schema, columnNameOfCorruptRecords)
  }

  /**
   * Parse the current token (and related children) according to a desired schema
   */
  private[sql] def convertField(
      factory: JsonFactory,
      parser: JsonParser,
      schema: DataType): Any = {
    import com.fasterxml.jackson.core.JsonToken._
    (parser.getCurrentToken, schema) match {
      case (null | VALUE_NULL, _) =>
        null

      case (FIELD_NAME, _) =>
        parser.nextToken()
        convertField(factory, parser, schema)

      case (VALUE_STRING, StringType) =>
        UTF8String(parser.getText)

      case (VALUE_STRING, _) if parser.getTextLength < 1 =>
        // guard the non string type
        null

      case (VALUE_STRING, DateType) =>
        DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime)

      case (VALUE_STRING, TimestampType) =>
        new Timestamp(DateUtils.stringToTime(parser.getText).getTime)

      case (VALUE_NUMBER_INT, TimestampType) =>
        new Timestamp(parser.getLongValue)

      case (_, StringType) =>
        val writer = new ByteArrayOutputStream()
        val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
        generator.copyCurrentStructure(parser)
        generator.close()
        UTF8String(writer.toByteArray)

      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
        parser.getFloatValue

      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
        parser.getDoubleValue

      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
        // TODO: add fixed precision and scale handling
        Decimal(parser.getDecimalValue)

      case (VALUE_NUMBER_INT, ByteType) =>
        parser.getByteValue

      case (VALUE_NUMBER_INT, ShortType) =>
        parser.getShortValue

      case (VALUE_NUMBER_INT, IntegerType) =>
        parser.getIntValue

      case (VALUE_NUMBER_INT, LongType) =>
        parser.getLongValue

      case (VALUE_TRUE, BooleanType) =>
        true

      case (VALUE_FALSE, BooleanType) =>
        false

      case (START_OBJECT, st: StructType) =>
        convertObject(factory, parser, st)

      case (START_ARRAY, ArrayType(st, _)) =>
        convertList(factory, parser, st)

      case (START_OBJECT, ArrayType(st, _)) =>
        // the business end of SPARK-3308:
        // when an object is found but an array is requested just wrap it in a list
        convertField(factory, parser, st) :: Nil

      case (START_OBJECT, MapType(StringType, kt, _)) =>
        convertMap(factory, parser, kt)

      case (_, udt: UserDefinedType[_]) =>
        udt.deserialize(convertField(factory, parser, udt.sqlType))
    }
  }

  /**
   * Parse an object from the token stream into a new Row representing the schema.
   *
   * Fields in the json that are not defined in the requested schema will be dropped.
   */
  private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = {
    val row = new GenericMutableRow(schema.length)
    while (nextUntil(parser, JsonToken.END_OBJECT)) {
      schema.getFieldIndex(parser.getCurrentName) match {
        case Some(index) =>
          row.update(index, convertField(factory, parser, schema(index).dataType))

        case None =>
          parser.skipChildren()
      }
    }

    row
  }

  /**
   * Parse an object as a Map, preserving all fields
   */
  private def convertMap(
      factory: JsonFactory,
      parser: JsonParser,
      valueType: DataType): Map[String, Any] = {
    val builder = Map.newBuilder[String, Any]
    while (nextUntil(parser, JsonToken.END_OBJECT)) {
      builder += parser.getCurrentName -> convertField(factory, parser, valueType)
    }

    builder.result()
  }

  private def convertList(
      factory: JsonFactory,
      parser: JsonParser,
      schema: DataType): Seq[Any] = {
    val builder = Seq.newBuilder[Any]
    while (nextUntil(parser, JsonToken.END_ARRAY)) {
      builder += convertField(factory, parser, schema)
    }

    builder.result()
  }

  private def parseJson(
      json: RDD[String],
      schema: StructType,
      columnNameOfCorruptRecords: String): RDD[Row] = {

    def failedRecord(record: String): Seq[Row] = {
      // create a row even if no corrupt record column is present
      val row = new GenericMutableRow(schema.length)
      for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
        require(schema(corruptIndex).dataType == StringType)
        row.update(corruptIndex, record)
      }

      Seq(row)
    }

    json.mapPartitions { iter =>
      val factory = new JsonFactory()

      iter.flatMap { record =>
        try {
          val parser = factory.createParser(record)
          parser.nextToken()

          // to support both object and arrays (see SPARK-3308) we'll start
          // by converting the StructType schema to an ArrayType and let
          // convertField wrap an object into a single value array when necessary.
          convertField(factory, parser, ArrayType(schema)) match {
            case null => failedRecord(record)
            case list: Seq[Row @unchecked] => list
            case _ =>
              sys.error(
                s"Failed to parse record $record. Please make sure that each line of the file " +
                "(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
          }
        } catch {
          case _: JsonProcessingException =>
            failedRecord(record)
        }
      }
    }
  }
}
```
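When a record fails to parse, `failedRecord` above still emits a row; if the schema contains the column configured by `spark.sql.columnNameOfCorruptRecord` (assumed here to default to `_corrupt_record`), the raw text lands in that column instead of failing the scan. A hedged sketch under that assumption:

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

// One valid record and one malformed record. The malformed line is expected to
// be preserved in the corrupt-record column rather than aborting the job.
val mixed = sc.parallelize(Seq(
  """{"a": 1}""",
  """{"a": 1"""   // missing closing brace, triggers JsonProcessingException
))

val schema = StructType(Seq(
  StructField("a", LongType, nullable = true),
  StructField("_corrupt_record", StringType, nullable = true)
))

// Expected: row 1 has a = 1; row 2 has a = null and the raw text in _corrupt_record.
sqlContext.jsonRDD(mixed, schema).show()
```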
```
@@ -0,0 +1,32 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.json

import com.fasterxml.jackson.core.{JsonParser, JsonToken}

private object JacksonUtils {
  /**
   * Advance the parser until a null or a specific token is found
   */
  def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
    parser.nextToken() match {
      case null => false
      case x => x != stopOn
    }
  }
}
```
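The whole patch is built on this kind of token-level traversal: `nextUntil` keeps pulling tokens until the closing token (or end of input) is reached. A standalone sketch of the same idea using only Jackson, outside of Spark, with a hypothetical flat record:

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonToken}

// Minimal illustration of nextUntil-style streaming traversal: walk a single
// flat JSON object and print its top-level field names and values.
val factory = new JsonFactory()
val parser = factory.createParser("""{"artist": "Robert Hood", "plays": 42}""")

parser.nextToken() // START_OBJECT
while (parser.nextToken() != null && parser.getCurrentToken != JsonToken.END_OBJECT) {
  val name = parser.getCurrentName // FIELD_NAME token
  parser.nextToken()               // advance to the value token
  println(s"$name -> ${parser.getText}")
}
parser.close()
```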
```
@@ -440,54 +440,4 @@ private[sql] object JsonRDD extends Logging {

    row
  }

  /** Transforms a single Row to JSON using Jackson
   *
   * @param rowSchema the schema object used for conversion
   * @param gen a JsonGenerator object
   * @param row The row to convert
   */
  private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
    def valWriter: (DataType, Any) => Unit = {
      case (_, null) | (NullType, _) => gen.writeNull()
      case (StringType, v: String) => gen.writeString(v)
      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
      case (IntegerType, v: Int) => gen.writeNumber(v)
      case (ShortType, v: Short) => gen.writeNumber(v)
      case (FloatType, v: Float) => gen.writeNumber(v)
      case (DoubleType, v: Double) => gen.writeNumber(v)
      case (LongType, v: Long) => gen.writeNumber(v)
      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
      case (DateType, v) => gen.writeString(v.toString)
      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)

      case (ArrayType(ty, _), v: Seq[_] ) =>
        gen.writeStartArray()
        v.foreach(valWriter(ty,_))
        gen.writeEndArray()

      case (MapType(kv,vv, _), v: Map[_,_]) =>
        gen.writeStartObject()
        v.foreach { p =>
          gen.writeFieldName(p._1.toString)
          valWriter(vv,p._2)
        }
        gen.writeEndObject()

      case (StructType(ty), v: Row) =>
        gen.writeStartObject()
        ty.zip(v.toSeq).foreach {
          case (_, null) =>
          case (field, v) =>
            gen.writeFieldName(field.name)
            valWriter(field.dataType, v)
        }
        gen.writeEndObject()
    }

    valWriter(rowSchema, row)
  }
}
```
```
@@ -17,13 +17,15 @@

package org.apache.spark.sql.json

import java.io.StringWriter
import java.sql.{Date, Timestamp}

import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._

import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._

@@ -46,6 +48,18 @@ class JsonSuite extends QueryTest {
        s"${expected}(${expected.getClass}).")
  }

  val factory = new JsonFactory()
  def enforceCorrectType(value: Any, dataType: DataType): Any = {
    val writer = new StringWriter()
    val generator = factory.createGenerator(writer)
    generator.writeObject(value)
    generator.flush()

    val parser = factory.createParser(writer.toString)
    parser.nextToken()
    JacksonParser.convertField(factory, parser, dataType)
  }

  val intNumber: Int = 2147483647
  checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
  checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))

@@ -439,7 +453,7 @@ class JsonSuite extends QueryTest {
    val jsonDF = jsonRDD(primitiveFieldValueTypeConflict)
    jsonDF.registerTempTable("jsonTable")

    // Right now, the analyzer does not promote strings in a boolean expreesion.
    // Right now, the analyzer does not promote strings in a boolean expression.
    // Number and Boolean conflict: resolve the type as boolean in this query.
    checkAnswer(
      sql("select num_bool from jsonTable where NOT num_bool"),

@@ -508,7 +522,7 @@ class JsonSuite extends QueryTest {
      Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
      Row(null, """{"field":false}""", null, null, "{}") ::
      Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
      Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil
      Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil
    )
  }


@@ -566,19 +580,19 @@ class JsonSuite extends QueryTest {
    val analyzed = jsonDF.queryExecution.analyzed
    assert(
      analyzed.isInstanceOf[LogicalRelation],
      "The DataFrame returned by jsonFile should be based on JSONRelation.")
      "The DataFrame returned by jsonFile should be based on LogicalRelation.")
    val relation = analyzed.asInstanceOf[LogicalRelation].relation
    assert(
      relation.isInstanceOf[JSONRelation],
      "The DataFrame returned by jsonFile should be based on JSONRelation.")
    assert(relation.asInstanceOf[JSONRelation].path === path)
    assert(relation.asInstanceOf[JSONRelation].path === Some(path))
    assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001))

    val schema = StructType(StructField("a", LongType, true) :: Nil)
    val logicalRelation =
      jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation]
    val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
    assert(relationWithSchema.path === path)
    assert(relationWithSchema.path === Some(path))
    assert(relationWithSchema.schema === schema)
    assert(relationWithSchema.samplingRatio > 0.99)
  }

@@ -1020,15 +1034,24 @@ class JsonSuite extends QueryTest {
  }

  test("JSONRelation equality test") {
    val relation1 =
      JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null)
    val context = org.apache.spark.sql.test.TestSQLContext
    val relation1 = new JSONRelation(
      "path",
      1.0,
      Some(StructType(StructField("a", IntegerType, true) :: Nil)),
      context)
    val logicalRelation1 = LogicalRelation(relation1)
    val relation2 =
      JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(
        org.apache.spark.sql.test.TestSQLContext)
    val relation2 = new JSONRelation(
      "path",
      0.5,
      Some(StructType(StructField("a", IntegerType, true) :: Nil)),
      context)
    val logicalRelation2 = LogicalRelation(relation2)
    val relation3 =
      JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null)
    val relation3 = new JSONRelation(
      "path",
      1.0,
      Some(StructType(StructField("b", StringType, true) :: Nil)),
      context)
    val logicalRelation3 = LogicalRelation(relation3)

    assert(relation1 === relation2)

@@ -1046,7 +1069,7 @@ class JsonSuite extends QueryTest {

  test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
    // This is really a test that it doesn't throw an exception
    val emptySchema = JsonRDD.inferSchema(empty, 1.0, "")
    val emptySchema = InferSchema(empty, 1.0, "")
    assert(StructType(Seq()) === emptySchema)
  }
```
```
@@ -21,7 +21,7 @@ import java.io.File

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
import org.apache.spark.util.Utils

class InsertSuite extends DataSourceTest with BeforeAndAfterAll {

@@ -100,23 +100,48 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
  test("INSERT OVERWRITE a JSONRelation multiple times") {
    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
      """.stripMargin)

    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
      """.stripMargin)

    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
      """.stripMargin)

        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
      """.stripMargin)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i, s"str$i"))
    )

    // Writing the table to less part files.
    val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
    jsonRDD(rdd1).registerTempTable("jt1")
    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
      """.stripMargin)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i, s"str$i"))
    )

    // Writing the table to more part files.
    val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
    jsonRDD(rdd2).registerTempTable("jt2")
    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
      """.stripMargin)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i, s"str$i"))
    )

    sql(
      s"""
        |INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1
      """.stripMargin)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i * 10, s"str$i"))
    )

    dropTempTable("jt1")
    dropTempTable("jt2")
  }

  test("INSERT INTO not supported for JSONRelation for now") {

@@ -128,6 +153,20 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
    }
  }

  test("save directly to the path of a JSON table") {
    table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i * 5, s"str$i"))
    )

    table("jt").save(path.toString, "json", SaveMode.Overwrite)
    checkAnswer(
      sql("SELECT a, b FROM jsonTable"),
      (1 to 10).map(i => Row(i, s"str$i"))
    )
  }

  test("it is not allowed to write to a table while querying it.") {
    val message = intercept[AnalysisException] {
      sql(
```