[SPARK-18295][SQL] Make to_json function null safe (matching it to from_json)
## What changes were proposed in this pull request? This PR proposes to match up the behaviour of `to_json` to `from_json` function for null-safety. Currently, it throws `NullPointerException` but this PR fixes this to produce `null` instead. With the data below: ```scala import spark.implicits._ val df = Seq(Some(Tuple1(Tuple1(1))), None).toDF("a") df.show() ``` ``` +----+ | a| +----+ | [1]| |null| +----+ ``` the code below ```scala import org.apache.spark.sql.functions._ df.select(to_json($"a")).show() ``` produces: **Before** throws `NullPointerException` as below: ``` java.lang.NullPointerException at org.apache.spark.sql.catalyst.json.JacksonGenerator.org$apache$spark$sql$catalyst$json$JacksonGenerator$$writeFields(JacksonGenerator.scala:138) at org.apache.spark.sql.catalyst.json.JacksonGenerator$$anonfun$write$1.apply$mcV$sp(JacksonGenerator.scala:194) at org.apache.spark.sql.catalyst.json.JacksonGenerator.org$apache$spark$sql$catalyst$json$JacksonGenerator$$writeObject(JacksonGenerator.scala:131) at org.apache.spark.sql.catalyst.json.JacksonGenerator.write(JacksonGenerator.scala:193) at org.apache.spark.sql.catalyst.expressions.StructToJson.eval(jsonExpressions.scala:544) at org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:142) at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:48) at org.apache.spark.sql.catalyst.expressions.InterpretedProjection.apply(Projection.scala:30) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) ``` **After** ``` +---------------+ |structtojson(a)| +---------------+ | {"_1":1}| | null| +---------------+ ``` ## How was this patch tested? Unit test in `JsonExpressionsSuite.scala` and `JsonFunctionsSuite.scala`. Author: hyukjinkwon <gurwls223@gmail.com> Closes #15792 from HyukjinKwon/SPARK-18295.
This commit is contained in:
parent
3a710b94b0
commit
3eda05703f
|
@ -484,7 +484,7 @@ case class JsonTuple(children: Seq[Expression])
|
|||
* Converts an json input string to a [[StructType]] with the specified schema.
|
||||
*/
|
||||
case class JsonToStruct(schema: StructType, options: Map[String, String], child: Expression)
|
||||
extends Expression with CodegenFallback with ExpectsInputTypes {
|
||||
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
|
||||
override def nullable: Boolean = true
|
||||
|
||||
@transient
|
||||
|
@ -495,11 +495,8 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
|
|||
new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE)))
|
||||
|
||||
override def dataType: DataType = schema
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val json = child.eval(input)
|
||||
if (json == null) return null
|
||||
override def nullSafeEval(json: Any): Any = {
|
||||
try parser.parse(json.toString).head catch {
|
||||
case _: SparkSQLJsonProcessingException => null
|
||||
}
|
||||
|
@ -512,7 +509,7 @@ case class JsonToStruct(schema: StructType, options: Map[String, String], child:
|
|||
* Converts a [[StructType]] to a json output string.
|
||||
*/
|
||||
case class StructToJson(options: Map[String, String], child: Expression)
|
||||
extends Expression with CodegenFallback with ExpectsInputTypes {
|
||||
extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
|
||||
override def nullable: Boolean = true
|
||||
|
||||
@transient
|
||||
|
@ -523,7 +520,6 @@ case class StructToJson(options: Map[String, String], child: Expression)
|
|||
new JacksonGenerator(child.dataType.asInstanceOf[StructType], writer)
|
||||
|
||||
override def dataType: DataType = StringType
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (StructType.acceptsType(child.dataType)) {
|
||||
|
@ -540,8 +536,8 @@ case class StructToJson(options: Map[String, String], child: Expression)
|
|||
}
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
gen.write(child.eval(input).asInstanceOf[InternalRow])
|
||||
override def nullSafeEval(row: Any): Any = {
|
||||
gen.write(row.asInstanceOf[InternalRow])
|
||||
gen.flush()
|
||||
val json = writer.toString
|
||||
writer.reset()
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.util.ParseModes
|
||||
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
|
@ -347,7 +347,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("from_json null input column") {
|
||||
val schema = StructType(StructField("a", IntegerType) :: Nil)
|
||||
checkEvaluation(
|
||||
JsonToStruct(schema, Map.empty, Literal(null)),
|
||||
JsonToStruct(schema, Map.empty, Literal.create(null, StringType)),
|
||||
null
|
||||
)
|
||||
}
|
||||
|
@ -360,4 +360,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
"""{"a":1}"""
|
||||
)
|
||||
}
|
||||
|
||||
test("to_json null input column") {
|
||||
val schema = StructType(StructField("a", IntegerType) :: Nil)
|
||||
val struct = Literal.create(null, schema)
|
||||
checkEvaluation(
|
||||
StructToJson(Map.empty, struct),
|
||||
null
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -141,4 +141,18 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
assert(e.getMessage.contains(
|
||||
"Unable to convert column a of type calendarinterval to JSON."))
|
||||
}
|
||||
|
||||
test("roundtrip in to_json and from_json") {
|
||||
val dfOne = Seq(Some(Tuple1(Tuple1(1))), None).toDF("struct")
|
||||
val schemaOne = dfOne.schema(0).dataType.asInstanceOf[StructType]
|
||||
val readBackOne = dfOne.select(to_json($"struct").as("json"))
|
||||
.select(from_json($"json", schemaOne).as("struct"))
|
||||
checkAnswer(dfOne, readBackOne)
|
||||
|
||||
val dfTwo = Seq(Some("""{"a":1}"""), None).toDF("json")
|
||||
val schemaTwo = new StructType().add("a", IntegerType)
|
||||
val readBackTwo = dfTwo.select(from_json($"json", schemaTwo).as("struct"))
|
||||
.select(to_json($"struct").as("json"))
|
||||
checkAnswer(dfTwo, readBackTwo)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue