diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f1093701dd..adc56e7ec0 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1231,6 +1231,13 @@ class SQLContext:
         ...   "field3.field5[0] as f3 from table3")
         >>> srdd6.collect()
         [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
+
+        >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
+        ... '{"key0": {"key1": "value1"}}'])).collect()
+        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+        >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
+        ... '{"key0": {"key1": "value1"}}'])).collect()
+        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
         """
 
         def func(iterator):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 57df79321b..33b2ed1b3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -382,21 +382,26 @@ class SchemaRDD(
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     import scala.collection.Map
 
-    def toJava(obj: Any, dataType: DataType): Any = dataType match {
-      case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct)
-      case array: ArrayType => obj match {
-        case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
-        case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
-        case arr if arr != null && arr.getClass.isArray =>
-          arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
-        case other => other
-      }
-      case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
+    def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+      case (null, _) => null
+
+      case (obj: Row, struct: StructType) => rowToArray(obj, struct)
+
+      case (seq: Seq[Any], array: ArrayType) =>
+        seq.map(x => toJava(x, array.elementType)).asJava
+      case (list: JList[_], array: ArrayType) =>
+        list.map(x => toJava(x, array.elementType)).asJava
+      case (arr, array: ArrayType) if arr.getClass.isArray =>
+        arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
+
+      case (obj: Map[_, _], mt: MapType) => obj.map {
         case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
       }.asJava
+
       // Pyrolite can handle Timestamp
-      case other => obj
+      case (other, _) => other
     }
+
     def rowToArray(row: Row, structType: StructType): Array[Any] = {
       val fields = structType.fields.map(field => field.dataType)
       row.zip(fields).map {
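
Reviewer note (not part of the patch): the crux of the Scala change is that
toJava now matches on the (obj, dataType) pair with a null case first, so a
missing or explicitly-null nested value is passed through as null instead of
reaching code like obj.asInstanceOf[Map[_, _]].map(...), which throws a
NullPointerException when obj is null. Below is a minimal, self-contained
sketch of that pattern; the object name and the DataType stand-ins are
assumptions for illustration, not the real org.apache.spark.sql classes.

object NullSafeToJava {
  // Simplified stand-ins for Spark SQL's DataType hierarchy (assumed for
  // this sketch; not the real org.apache.spark.sql types).
  sealed trait DataType
  case object StringType extends DataType
  final case class ArrayType(elementType: DataType) extends DataType

  // The pattern the patch introduces: dispatch on the (value, type) pair and
  // put the null case first, so null never reaches a cast or method call in
  // a later case.
  def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
    case (null, _) => null
    case (seq: Seq[_], array: ArrayType) =>
      seq.map(x => toJava(x, array.elementType))
    case (other, _) => other
  }

  def main(args: Array[String]): Unit = {
    println(toJava(null, ArrayType(StringType)))           // null, no NPE
    println(toJava(Seq("a", null), ArrayType(StringType))) // List(a, null)
  }
}

This is the behavior the new jsonRDD doctests pin down: both '{}' (field
absent) and '{"key0": null}' (field explicitly null) now convert to
Row(key0=None) on the Python side instead of crashing the conversion.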