[SPARK-5938] [SPARK-5443] [SQL] Improve JsonRDD performance

This patch comprises of a few related pieces of work:

* Schema inference is performed directly on the JSON token stream
* `String => Row` conversion populate Spark SQL structures without intermediate types
* Projection pushdown is implemented via CatalystScan for DataFrame queries
* Support for the legacy parser by setting `spark.sql.json.useJacksonStreamingAPI` to `false`

Performance improvements depend on the schema and queries being executed, but it should be faster across the board. Below are benchmarks using the last.fm Million Song dataset:

```
Command                                            | Baseline | Patched
---------------------------------------------------|----------|--------
import sqlContext.implicits._                      |          |
val df = sqlContext.jsonFile("/tmp/lastfm.json")   |    70.0s |   14.6s
df.count()                                         |    28.8s |    6.2s
df.rdd.count()                                     |    35.3s |   21.5s
df.where($"artist" === "Robert Hood").collect()    |    28.3s |   16.9s
```

To prepare this dataset for benchmarking, follow these steps:

```
# Fetch the datasets from http://labrosa.ee.columbia.edu/millionsong/lastfm
wget http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_test.zip \
     http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_train.zip

# Decompress and combine, pipe through `jq -c` to ensure there is one record per line
unzip -p lastfm_test.zip lastfm_train.zip  | jq -c . > lastfm.json
```

Author: Nathan Howell <nhowell@godaddy.com>

Closes #5801 from NathanHowell/json-performance and squashes the following commits:

26fea31 [Nathan Howell] Recreate the baseRDD each for each scan operation
a7ebeb2 [Nathan Howell] Increase coverage of inserts into a JSONRelation
e06a1dd [Nathan Howell] Add comments to the `useJacksonStreamingAPI` config flag
6822712 [Nathan Howell] Split up JsonRDD2 into multiple objects
fa8234f [Nathan Howell] Wrap long lines
b31917b [Nathan Howell] Rename `useJsonRDD2` to `useJacksonStreamingAPI`
15c5d1b [Nathan Howell] JSONRelation's baseRDD need not be lazy
f8add6e [Nathan Howell] Add comments on lack of support for precision and scale DecimalTypes
fa0be47 [Nathan Howell] Remove unused default case in the field parser
80dba17 [Nathan Howell] Add comments regarding null handling and empty strings
842846d [Nathan Howell] Point the empty schema inference test at JsonRDD2
ab6ee87 [Nathan Howell] Add projection pushdown support to JsonRDD/JsonRDD2
f636c14 [Nathan Howell] Enable JsonRDD2 by default, add a flag to switch back to JsonRDD
0bbc445 [Nathan Howell] Improve JSON parsing and type inference performance
7ca70c1 [Nathan Howell] Eliminate arrow pattern, replace with pattern matches
This commit is contained in:
Nathan Howell 2015-05-06 22:56:53 -07:00 committed by Yin Huai
parent 9cfa9a516e
commit 2d6612cc8b
13 changed files with 720 additions and 133 deletions

View file

@ -26,7 +26,14 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
private val numericPrecedence =
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
IndexedSeq(
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType.Unlimited)
/**
* Find the tightest common type of two types that might be used in a binary expression.
@ -34,25 +41,21 @@ object HiveTypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
*/
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
} else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))
// Fixed-precision decimals can up-cast into unlimited
if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
Some(DecimalType.Unlimited)
} else {
None
}
} else {
None
}
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)
case _ => None
}
}

View file

@ -134,6 +134,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}
private[sql] def getFieldIndex(name: String): Option[Int] = {
nameToIndex.get(name)
}
protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

View file

@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@ -1415,7 +1415,7 @@ class DataFrame private[sql](
new Iterator[String] {
override def hasNext: Boolean = iter.hasNext
override def next(): String = {
JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
JacksonGenerator(rowSchema, gen)(iter.next())
gen.flush()
val json = writer.toString

View file

@ -73,6 +73,8 @@ private[spark] object SQLConf {
val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@ -166,6 +168,12 @@ private[sql] class SQLConf extends Serializable {
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
/**
* Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
*/
private[spark] def useJacksonStreamingAPI: Boolean =
getConf(USE_JACKSON_STREAMING_API, "true").toBoolean
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1

View file

@ -659,6 +659,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
if (conf.useJacksonStreamingAPI) {
baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this))
} else {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
Option(schema).getOrElse(
@ -667,6 +670,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
}
/**
* :: Experimental ::
@ -689,6 +693,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
if (conf.useJacksonStreamingAPI) {
baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this))
} else {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
JsonRDD.nullTypeToStringType(
@ -696,6 +703,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
}
/**
* :: Experimental ::

View file

@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.json
import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
private[sql] object InferSchema {
/**
* Infer the type of a collection of json records in three stages:
* 1. Infer the type of each record
* 2. Merge types by choosing the lowest type necessary to cover equal keys
* 3. Replace any remaining null fields with string, the top type
*/
def apply(
json: RDD[String],
samplingRatio: Double = 1.0,
columnNameOfCorruptRecords: String): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, samplingRatio, 1)
}
// perform schema inference on each row and merge afterwards
schemaData.mapPartitions { iter =>
val factory = new JsonFactory()
iter.map { row =>
try {
val parser = factory.createParser(row)
parser.nextToken()
inferField(parser)
} catch {
case _: JsonParseException =>
StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
}
}
}.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match {
case st: StructType => nullTypeToStringType(st)
}
}
/**
* Infer the type of a json document from the parser's token stream
*/
private def inferField(parser: JsonParser): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
case FIELD_NAME =>
parser.nextToken()
inferField(parser)
case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
// with JSON generators that do not distinguish between the two.
// To accurately infer types for empty strings that are really
// meant to represent nulls we assume that the two are isomorphic
// but will defer treating null fields as strings until all the
// record fields' types have been combined.
NullType
case VALUE_STRING => StringType
case START_OBJECT =>
val builder = Seq.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
}
StructType(builder.result().sortBy(_.name))
case START_ARRAY =>
// If this JSON array is empty, we use NullType as a placeholder.
// If this array is not empty in other JSON objects, we can resolve
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(elementType, inferField(parser))
}
ArrayType(elementType)
case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
parser.getNumberType match {
// For Integer values, use LongType by default.
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
case FLOAT | DOUBLE => DoubleType
}
case VALUE_TRUE | VALUE_FALSE => BooleanType
}
}
private def nullTypeToStringType(struct: StructType): StructType = {
val fields = struct.fields.map {
case StructField(fieldName, dataType, nullable, _) =>
val newType = dataType match {
case NullType => StringType
case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
case ArrayType(struct: StructType, containsNull) =>
ArrayType(nullTypeToStringType(struct), containsNull)
case struct: StructType =>nullTypeToStringType(struct)
case other: DataType => other
}
StructField(fieldName, newType, nullable)
}
StructType(fields)
}
/**
* Remove top-level ArrayType wrappers and merge the remaining schemas
*/
private def compatibleRootType: (DataType, DataType) => DataType = {
case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2)
case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2)
case (ty1, ty2) => compatibleType(ty1, ty2)
}
/**
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) =>
val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
StructField(name, dataType, nullable = true)
}
StructType(newFields.toSeq.sortBy(_.name))
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// strings and every string is a Json object.
case (_, _) => StringType
}
}
}
}

View file

@ -22,14 +22,16 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
private[sql] class DefaultSource
extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
extends RelationProvider
with SchemaRelationProvider
with CreatableRelationProvider {
private def checkPath(parameters: Map[String, String]): String = {
parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
@ -42,7 +44,7 @@ private[sql] class DefaultSource
val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
JSONRelation(path, samplingRatio, None)(sqlContext)
new JSONRelation(path, samplingRatio, None, sqlContext)
}
/** Returns a new base relation with the given schema and parameters. */
@ -53,7 +55,7 @@ private[sql] class DefaultSource
val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
JSONRelation(path, samplingRatio, Some(schema))(sqlContext)
new JSONRelation(path, samplingRatio, Some(schema), sqlContext)
}
override def createRelation(
@ -101,32 +103,87 @@ private[sql] class DefaultSource
}
}
private[sql] case class JSONRelation(
path: String,
samplingRatio: Double,
private[sql] class JSONRelation(
// baseRDD is not immutable with respect to INSERT OVERWRITE
// and so it must be recreated at least as often as the
// underlying inputs are modified. To be safe, a function is
// used instead of a regular RDD value to ensure a fresh RDD is
// recreated for each and every operation.
baseRDD: () => RDD[String],
val path: Option[String],
val samplingRatio: Double,
userSpecifiedSchema: Option[StructType])(
@transient val sqlContext: SQLContext)
extends BaseRelation
with TableScan
with InsertableRelation {
with InsertableRelation
with CatalystScan {
// TODO: Support partitioned JSON relation.
private def baseRDD = sqlContext.sparkContext.textFile(path)
def this(
path: String,
samplingRatio: Double,
userSpecifiedSchema: Option[StructType],
sqlContext: SQLContext) =
this(
() => sqlContext.sparkContext.textFile(path),
Some(path),
samplingRatio,
userSpecifiedSchema)(sqlContext)
private val useJacksonStreamingAPI: Boolean = sqlContext.conf.useJacksonStreamingAPI
override val needConversion: Boolean = false
override val schema = userSpecifiedSchema.getOrElse(
override lazy val schema = userSpecifiedSchema.getOrElse {
if (useJacksonStreamingAPI) {
InferSchema(
baseRDD(),
samplingRatio,
sqlContext.conf.columnNameOfCorruptRecord)
} else {
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(
baseRDD,
baseRDD(),
samplingRatio,
sqlContext.conf.columnNameOfCorruptRecord)))
sqlContext.conf.columnNameOfCorruptRecord))
}
}
override def buildScan(): RDD[Row] =
JsonRDD.jsonStringToRow(baseRDD, schema, sqlContext.conf.columnNameOfCorruptRecord)
override def buildScan(): RDD[Row] = {
if (useJacksonStreamingAPI) {
JacksonParser(
baseRDD(),
schema,
sqlContext.conf.columnNameOfCorruptRecord)
} else {
JsonRDD.jsonStringToRow(
baseRDD(),
schema,
sqlContext.conf.columnNameOfCorruptRecord)
}
}
override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = {
if (useJacksonStreamingAPI) {
JacksonParser(
baseRDD(),
StructType.fromAttributes(requiredColumns),
sqlContext.conf.columnNameOfCorruptRecord)
} else {
JsonRDD.jsonStringToRow(
baseRDD(),
StructType.fromAttributes(requiredColumns),
sqlContext.conf.columnNameOfCorruptRecord)
}
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
val filesystemPath = new Path(path)
val filesystemPath = path match {
case Some(p) => new Path(p)
case None =>
throw new IOException(s"Cannot INSERT into table with no path defined")
}
val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
if (overwrite) {
@ -147,7 +204,7 @@ private[sql] case class JSONRelation(
}
}
// Write the data.
data.toJSON.saveAsTextFile(path)
data.toJSON.saveAsTextFile(filesystemPath.toString)
// Right now, we assume that the schema is not changed. We will not update the schema.
// schema = data.schema
} else {

View file

@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.json
import scala.collection.Map
import com.fasterxml.jackson.core._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
private[sql] object JacksonGenerator {
/** Transforms a single Row to JSON using Jackson
*
* @param rowSchema the schema object used for conversion
* @param gen a JsonGenerator object
* @param row The row to convert
*/
def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v)
case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
case (IntegerType, v: Int) => gen.writeNumber(v)
case (ShortType, v: Short) => gen.writeNumber(v)
case (FloatType, v: Float) => gen.writeNumber(v)
case (DoubleType, v: Double) => gen.writeNumber(v)
case (LongType, v: Long) => gen.writeNumber(v)
case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
case (BooleanType, v: Boolean) => gen.writeBoolean(v)
case (DateType, v) => gen.writeString(v.toString)
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v))
case (ArrayType(ty, _), v: Seq[_] ) =>
gen.writeStartArray()
v.foreach(valWriter(ty,_))
gen.writeEndArray()
case (MapType(kv,vv, _), v: Map[_,_]) =>
gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
valWriter(vv,p._2)
}
gen.writeEndObject()
case (StructType(ty), v: Row) =>
gen.writeStartObject()
ty.zip(v.toSeq).foreach {
case (_, null) =>
case (field, v) =>
gen.writeFieldName(field.name)
valWriter(field.dataType, v)
}
gen.writeEndObject()
}
valWriter(rowSchema, row)
}
}

View file

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.json
import java.io.ByteArrayOutputStream
import java.sql.Timestamp
import scala.collection.Map
import com.fasterxml.jackson.core._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
private[sql] object JacksonParser {
def apply(
json: RDD[String],
schema: StructType,
columnNameOfCorruptRecords: String): RDD[Row] = {
parseJson(json, schema, columnNameOfCorruptRecords)
}
/**
* Parse the current token (and related children) according to a desired schema
*/
private[sql] def convertField(
factory: JsonFactory,
parser: JsonParser,
schema: DataType): Any = {
import com.fasterxml.jackson.core.JsonToken._
(parser.getCurrentToken, schema) match {
case (null | VALUE_NULL, _) =>
null
case (FIELD_NAME, _) =>
parser.nextToken()
convertField(factory, parser, schema)
case (VALUE_STRING, StringType) =>
UTF8String(parser.getText)
case (VALUE_STRING, _) if parser.getTextLength < 1 =>
// guard the non string type
null
case (VALUE_STRING, DateType) =>
DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime)
case (VALUE_STRING, TimestampType) =>
new Timestamp(DateUtils.stringToTime(parser.getText).getTime)
case (VALUE_NUMBER_INT, TimestampType) =>
new Timestamp(parser.getLongValue)
case (_, StringType) =>
val writer = new ByteArrayOutputStream()
val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
generator.copyCurrentStructure(parser)
generator.close()
UTF8String(writer.toByteArray)
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
parser.getFloatValue
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DoubleType) =>
parser.getDoubleValue
case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, DecimalType()) =>
// TODO: add fixed precision and scale handling
Decimal(parser.getDecimalValue)
case (VALUE_NUMBER_INT, ByteType) =>
parser.getByteValue
case (VALUE_NUMBER_INT, ShortType) =>
parser.getShortValue
case (VALUE_NUMBER_INT, IntegerType) =>
parser.getIntValue
case (VALUE_NUMBER_INT, LongType) =>
parser.getLongValue
case (VALUE_TRUE, BooleanType) =>
true
case (VALUE_FALSE, BooleanType) =>
false
case (START_OBJECT, st: StructType) =>
convertObject(factory, parser, st)
case (START_ARRAY, ArrayType(st, _)) =>
convertList(factory, parser, st)
case (START_OBJECT, ArrayType(st, _)) =>
// the business end of SPARK-3308:
// when an object is found but an array is requested just wrap it in a list
convertField(factory, parser, st) :: Nil
case (START_OBJECT, MapType(StringType, kt, _)) =>
convertMap(factory, parser, kt)
case (_, udt: UserDefinedType[_]) =>
udt.deserialize(convertField(factory, parser, udt.sqlType))
}
}
/**
* Parse an object from the token stream into a new Row representing the schema.
*
* Fields in the json that are not defined in the requested schema will be dropped.
*/
private def convertObject(factory: JsonFactory, parser: JsonParser, schema: StructType): Row = {
val row = new GenericMutableRow(schema.length)
while (nextUntil(parser, JsonToken.END_OBJECT)) {
schema.getFieldIndex(parser.getCurrentName) match {
case Some(index) =>
row.update(index, convertField(factory, parser, schema(index).dataType))
case None =>
parser.skipChildren()
}
}
row
}
/**
* Parse an object as a Map, preserving all fields
*/
private def convertMap(
factory: JsonFactory,
parser: JsonParser,
valueType: DataType): Map[String, Any] = {
val builder = Map.newBuilder[String, Any]
while (nextUntil(parser, JsonToken.END_OBJECT)) {
builder += parser.getCurrentName -> convertField(factory, parser, valueType)
}
builder.result()
}
private def convertList(
factory: JsonFactory,
parser: JsonParser,
schema: DataType): Seq[Any] = {
val builder = Seq.newBuilder[Any]
while (nextUntil(parser, JsonToken.END_ARRAY)) {
builder += convertField(factory, parser, schema)
}
builder.result()
}
private def parseJson(
json: RDD[String],
schema: StructType,
columnNameOfCorruptRecords: String): RDD[Row] = {
def failedRecord(record: String): Seq[Row] = {
// create a row even if no corrupt record column is present
val row = new GenericMutableRow(schema.length)
for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) {
require(schema(corruptIndex).dataType == StringType)
row.update(corruptIndex, record)
}
Seq(row)
}
json.mapPartitions { iter =>
val factory = new JsonFactory()
iter.flatMap { record =>
try {
val parser = factory.createParser(record)
parser.nextToken()
// to support both object and arrays (see SPARK-3308) we'll start
// by converting the StructType schema to an ArrayType and let
// convertField wrap an object into a single value array when necessary.
convertField(factory, parser, ArrayType(schema)) match {
case null => failedRecord(record)
case list: Seq[Row @unchecked] => list
case _ =>
sys.error(
s"Failed to parse record $record. Please make sure that each line of the file " +
"(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
}
} catch {
case _: JsonProcessingException =>
failedRecord(record)
}
}
}
}
}

View file

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.json
import com.fasterxml.jackson.core.{JsonParser, JsonToken}
private object JacksonUtils {
/**
* Advance the parser until a null or a specific token is found
*/
def nextUntil(parser: JsonParser, stopOn: JsonToken): Boolean = {
parser.nextToken() match {
case null => false
case x => x != stopOn
}
}
}

View file

@ -440,54 +440,4 @@ private[sql] object JsonRDD extends Logging {
row
}
/** Transforms a single Row to JSON using Jackson
*
* @param rowSchema the schema object used for conversion
* @param gen a JsonGenerator object
* @param row The row to convert
*/
private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v)
case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
case (IntegerType, v: Int) => gen.writeNumber(v)
case (ShortType, v: Short) => gen.writeNumber(v)
case (FloatType, v: Float) => gen.writeNumber(v)
case (DoubleType, v: Double) => gen.writeNumber(v)
case (LongType, v: Long) => gen.writeNumber(v)
case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
case (BooleanType, v: Boolean) => gen.writeBoolean(v)
case (DateType, v) => gen.writeString(v.toString)
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)
case (ArrayType(ty, _), v: Seq[_] ) =>
gen.writeStartArray()
v.foreach(valWriter(ty,_))
gen.writeEndArray()
case (MapType(kv,vv, _), v: Map[_,_]) =>
gen.writeStartObject()
v.foreach { p =>
gen.writeFieldName(p._1.toString)
valWriter(vv,p._2)
}
gen.writeEndObject()
case (StructType(ty), v: Row) =>
gen.writeStartObject()
ty.zip(v.toSeq).foreach {
case (_, null) =>
case (field, v) =>
gen.writeFieldName(field.name)
valWriter(field.dataType, v)
}
gen.writeEndObject()
}
valWriter(rowSchema, row)
}
}

View file

@ -17,13 +17,15 @@
package org.apache.spark.sql.json
import java.io.StringWriter
import java.sql.{Date, Timestamp}
import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@ -46,6 +48,18 @@ class JsonSuite extends QueryTest {
s"${expected}(${expected.getClass}).")
}
val factory = new JsonFactory()
def enforceCorrectType(value: Any, dataType: DataType): Any = {
val writer = new StringWriter()
val generator = factory.createGenerator(writer)
generator.writeObject(value)
generator.flush()
val parser = factory.createParser(writer.toString)
parser.nextToken()
JacksonParser.convertField(factory, parser, dataType)
}
val intNumber: Int = 2147483647
checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
@ -439,7 +453,7 @@ class JsonSuite extends QueryTest {
val jsonDF = jsonRDD(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expreesion.
// Right now, the analyzer does not promote strings in a boolean expression.
// Number and Boolean conflict: resolve the type as boolean in this query.
checkAnswer(
sql("select num_bool from jsonTable where NOT num_bool"),
@ -508,7 +522,7 @@ class JsonSuite extends QueryTest {
Row(Seq(), "11", "[1,2,3]", Row(null), "[]") ::
Row(null, """{"field":false}""", null, null, "{}") ::
Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") ::
Row(Seq(7), "{}","[str1,str2,33]", Row("str"), """{"field":true}""") :: Nil
Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil
)
}
@ -566,19 +580,19 @@ class JsonSuite extends QueryTest {
val analyzed = jsonDF.queryExecution.analyzed
assert(
analyzed.isInstanceOf[LogicalRelation],
"The DataFrame returned by jsonFile should be based on JSONRelation.")
"The DataFrame returned by jsonFile should be based on LogicalRelation.")
val relation = analyzed.asInstanceOf[LogicalRelation].relation
assert(
relation.isInstanceOf[JSONRelation],
"The DataFrame returned by jsonFile should be based on JSONRelation.")
assert(relation.asInstanceOf[JSONRelation].path === path)
assert(relation.asInstanceOf[JSONRelation].path === Some(path))
assert(relation.asInstanceOf[JSONRelation].samplingRatio === (0.49 +- 0.001))
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.path === path)
assert(relationWithSchema.path === Some(path))
assert(relationWithSchema.schema === schema)
assert(relationWithSchema.samplingRatio > 0.99)
}
@ -1020,15 +1034,24 @@ class JsonSuite extends QueryTest {
}
test("JSONRelation equality test") {
val relation1 =
JSONRelation("path", 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(null)
val context = org.apache.spark.sql.test.TestSQLContext
val relation1 = new JSONRelation(
"path",
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
context)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 =
JSONRelation("path", 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)))(
org.apache.spark.sql.test.TestSQLContext)
val relation2 = new JSONRelation(
"path",
0.5,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
context)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 =
JSONRelation("path", 1.0, Some(StructType(StructField("b", StringType, true) :: Nil)))(null)
val relation3 = new JSONRelation(
"path",
1.0,
Some(StructType(StructField("b", StringType, true) :: Nil)),
context)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation1 === relation2)
@ -1046,7 +1069,7 @@ class JsonSuite extends QueryTest {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
val emptySchema = JsonRDD.inferSchema(empty, 1.0, "")
val emptySchema = InferSchema(empty, 1.0, "")
assert(StructType(Seq()) === emptySchema)
}

View file

@ -21,7 +21,7 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
@ -102,21 +102,46 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
""".stripMargin)
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
""".stripMargin)
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt
""".stripMargin)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
)
// Writing the table to less part files.
val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
jsonRDD(rdd1).registerTempTable("jt1")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
""".stripMargin)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
)
// Writing the table to more part files.
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
jsonRDD(rdd2).registerTempTable("jt2")
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
""".stripMargin)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
)
sql(
s"""
|INSERT OVERWRITE TABLE jsonTable SELECT a * 10, b FROM jt1
""".stripMargin)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i * 10, s"str$i"))
)
dropTempTable("jt1")
dropTempTable("jt2")
}
test("INSERT INTO not supported for JSONRelation for now") {
@ -128,6 +153,20 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
}
test("save directly to the path of a JSON table") {
table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i * 5, s"str$i"))
)
table("jt").save(path.toString, "json", SaveMode.Overwrite)
checkAnswer(
sql("SELECT a, b FROM jsonTable"),
(1 to 10).map(i => Row(i, s"str$i"))
)
}
test("it is not allowed to write to a table while querying it.") {
val message = intercept[AnalysisException] {
sql(