[SPARK-2010] [PySpark] [SQL] support nested structure in SchemaRDD

Convert each Row in a JavaSchemaRDD into an Array[Any], unpickle them as tuples in Python, then convert them into namedtuples, so users can access fields just like attributes.

This lets nested structures be accessed as objects; it also reduces the size of the serialized data and improves performance.

root
 |-- field1: integer (nullable = true)
 |-- field2: string (nullable = true)
 |-- field3: struct (nullable = true)
 |    |-- field4: integer (nullable = true)
 |    |-- field5: array (nullable = true)
 |    |    |-- element: integer (containsNull = false)
 |-- field6: array (nullable = true)
 |    |-- element: struct (containsNull = false)
 |    |    |-- field7: string (nullable = true)

Then we can access the nested fields as row.field3.field5[0] or row.field6[5].field7.
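
For example, a minimal sketch against the schema above (assuming a SparkContext sc, a SQLContext ctx, and hypothetical values):

rdd = sc.parallelize([Row(field1=1, field2="s",
                          field3=Row(field4=2, field5=[3, 4]),
                          field6=[Row(field7="a")])])
row = ctx.inferSchema(rdd).first()
row.field3.field5[0]   # 3
row.field6[0].field7   # 'a'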

The schema can also be inferred in Python: Row/dict/namedtuple/objects are converted into tuples before serialization, then applySchema is called in the JVM. During inferSchema(), a top-level dict in a row becomes a StructType, but any nested dictionary becomes a MapType.
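
For example, a hypothetical sketch of this rule (again assuming a SparkContext sc and a SQLContext ctx):

rdd = sc.parallelize([{"a": 1, "b": {"x": 1}}])
srdd = ctx.inferSchema(rdd)
# "a" becomes a field of the top-level StructType,
# while the nested dict "b" is inferred as a MapType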

You can use pyspark.sql.Row to convert an unnamed structure into a Row object, making the RDD inferable. For example:

ctx.inferSchema(rdd.map(lambda x: Row(a=x[0], b=x[1])))

Or you can use Row to create a class that behaves like a namedtuple, for example:

Person = Row("name", "age")
ctx.inferSchema(rdd.map(lambda x: Person(*x)))

Also, you can call applySchema to apply a schema to an RDD of tuples/lists and turn it into a SchemaRDD. The `schema` should be a StructType; see the API docs for details.

schema = StructType([StructField("name", StringType(), True),
                     StructField("age", IntegerType(), True)])
ctx.applySchema(rdd, schema)
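
A brief usage sketch of the result (the sample data here is hypothetical):

rdd = sc.parallelize([("Alice", 1)])
srdd = ctx.applySchema(rdd, schema)
srdd.first().name   # 'Alice'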

PS: In order to use namedtuple with inferSchema, you should make the namedtuple picklable.
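
For example, a minimal sketch of one common way to do that, assuming rdd contains (name, age) tuples (defining the namedtuple at module top level so pickle can locate it by name):

from collections import namedtuple

Person = namedtuple("Person", ["name", "age"])  # module-level, hence picklable

ctx.inferSchema(rdd.map(lambda x: Person(*x)))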

Author: Davies Liu <davies.liu@gmail.com>

Closes #1598 from davies/nested and squashes the following commits:

f1d15b6 [Davies Liu] verify schema with the first few rows
8852aaf [Davies Liu] check type of schema
abe9e6e [Davies Liu] address comments
61b2292 [Davies Liu] add @deprecated to pythonToJavaMap
1e5b801 [Davies Liu] improve cache of classes
51aa135 [Davies Liu] use Row to infer schema
e9c0d5c [Davies Liu] remove string typed schema
353a3f2 [Davies Liu] fix code style
63de8f8 [Davies Liu] fix typo
c79ca67 [Davies Liu] fix serialization of nested data
6b258b5 [Davies Liu] fix pep8
9d8447c [Davies Liu] apply schema provided by string of names
f5df97f [Davies Liu] refactor, address comments
9d9af55 [Davies Liu] use arrry to applySchema and infer schema in Python
84679b3 [Davies Liu] Merge branch 'master' of github.com:apache/spark into nested
0eaaf56 [Davies Liu] fix doc tests
b3559b4 [Davies Liu] use generated Row instead of namedtuple
c4ddc30 [Davies Liu] fix conflict between name of fields and variables
7f6f251 [Davies Liu] address all comments
d69d397 [Davies Liu] refactor
2cc2d45 [Davies Liu] refactor
182fb46 [Davies Liu] refactor
bc6e9e1 [Davies Liu] switch to new Schema API
547bf3e [Davies Liu] Merge branch 'master' into nested
a435b5a [Davies Liu] add docs and code refactor
2c8debc [Davies Liu] Merge branch 'master' into nested
644665a [Davies Liu] use tuple and namedtuple for schemardd
Davies Liu 2014-08-01 18:47:41 -07:00 committed by Michael Armbrust
parent 7058a5393b
commit 880eabec37
5 changed files with 995 additions and 443 deletions

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

@@ -25,7 +25,7 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
 import scala.collection.JavaConversions._
 import scala.language.existentials
 import scala.reflect.ClassTag
-import scala.util.Try
+import scala.util.{Try, Success, Failure}

 import net.razorvine.pickle.{Pickler, Unpickler}
@@ -536,25 +536,6 @@ private[spark] object PythonRDD extends Logging {
     file.close()
   }

-  /**
-   * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
-   * It is only used by pyspark.sql.
-   * TODO: Support more Python types.
-   */
-  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
-    pyRDD.rdd.mapPartitions { iter =>
-      val unpickle = new Unpickler
-      iter.flatMap { row =>
-        unpickle.loads(row) match {
-          // in case of objects are pickled in batch mode
-          case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
-          // not in batch mode
-          case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
-        }
-      }
-    }
-  }
-
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
       baseConf: Configuration): Configuration = {
     val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -701,6 +682,54 @@ private[spark] object PythonRDD extends Logging {
     }
   }

+  /**
+   * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
+   * This function is outdated, PySpark does not use it anymore
+   */
+  @deprecated
+  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      val unpickle = new Unpickler
+      iter.flatMap { row =>
+        unpickle.loads(row) match {
+          // in case of objects are pickled in batch mode
+          case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
+          // not in batch mode
+          case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
+        }
+      }
+    }
+  }
+
+  /**
+   * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
+   * It is only used by pyspark.sql.
+   */
+  def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
+
+    def toArray(obj: Any): Array[_] = {
+      obj match {
+        case objs: JArrayList[_] =>
+          objs.toArray
+        case obj if obj.getClass.isArray =>
+          obj.asInstanceOf[Array[_]].toArray
+      }
+    }
+
+    pyRDD.rdd.mapPartitions { iter =>
+      val unpickle = new Unpickler
+      iter.flatMap { row =>
+        val obj = unpickle.loads(row)
+        if (batched) {
+          obj.asInstanceOf[JArrayList[_]].map(toArray)
+        } else {
+          Seq(toArray(obj))
+        }
+      }
+    }.toJavaRDD()
+  }
+
   /**
    * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by
    * PySpark.

python/pyspark/rdd.py

@@ -318,9 +318,9 @@ class RDD(object):
         >>> sorted(rdd.map(lambda x: (x, 1)).collect())
         [('a', 1), ('b', 1), ('c', 1)]
         """
-        def func(split, iterator):
+        def func(_, iterator):
             return imap(f, iterator)
-        return PipelinedRDD(self, func, preservesPartitioning)
+        return self.mapPartitionsWithIndex(func, preservesPartitioning)

     def flatMap(self, f, preservesPartitioning=False):
         """
@@ -1184,7 +1184,7 @@ class RDD(object):
                 if not isinstance(x, basestring):
                     x = unicode(x)
                 yield x.encode("utf-8")
-        keyed = PipelinedRDD(self, func)
+        keyed = self.mapPartitionsWithIndex(func)
         keyed._bypass_serializer = True
         keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
@@ -1382,7 +1382,7 @@ class RDD(object):
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)

-        keyed = PipelinedRDD(self, add_shuffle_key)
+        keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
         with _JavaStackTrace(self.context) as st:
             pairRDD = self.ctx._jvm.PairwiseRDD(

python/pyspark/sql.py (file diff suppressed because it is too large)

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

@@ -411,35 +411,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
     """.stripMargin.trim
   }

-  /**
-   * Peek at the first row of the RDD and infer its schema.
-   * It is only used by PySpark.
-   */
-  private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
-    import scala.collection.JavaConversions._
-
-    def typeOfComplexValue: PartialFunction[Any, DataType] = {
-      case c: java.util.Calendar => TimestampType
-      case c: java.util.List[_] =>
-        ArrayType(typeOfObject(c.head))
-      case c: java.util.Map[_, _] =>
-        val (key, value) = c.head
-        MapType(typeOfObject(key), typeOfObject(value))
-      case c if c.getClass.isArray =>
-        val elem = c.asInstanceOf[Array[_]].head
-        ArrayType(typeOfObject(elem))
-      case c => throw new Exception(s"Object of type $c cannot be used")
-    }
-    def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue
-
-    val firstRow = rdd.first()
-    val fields = firstRow.map {
-      case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true)
-    }.toSeq
-
-    applySchemaToPythonRDD(rdd, StructType(fields))
-  }
-
   /**
    * Parses the data type in our internal string representation. The data type string should
    * have the same format as the one generated by `toString` in scala.
@@ -454,7 +425,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
    */
   private[sql] def applySchemaToPythonRDD(
-      rdd: RDD[Map[String, _]],
+      rdd: RDD[Array[Any]],
       schemaString: String): SchemaRDD = {
     val schema = parseDataType(schemaString).asInstanceOf[StructType]
     applySchemaToPythonRDD(rdd, schema)
@@ -464,10 +435,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * Apply a schema defined by the schema to an RDD. It is only used by PySpark.
    */
   private[sql] def applySchemaToPythonRDD(
-      rdd: RDD[Map[String, _]],
+      rdd: RDD[Array[Any]],
       schema: StructType): SchemaRDD = {
-    // TODO: We should have a better implementation once we do not turn a Python side record
-    // to a Map.

     import scala.collection.JavaConversions._
     import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
@@ -494,55 +463,39 @@ class SQLContext(@transient val sparkContext: SparkContext)
         val converted = c.map { e => convert(e, elementType)}
         JListWrapper(converted)
-      case (c: java.util.Map[_, _], struct: StructType) =>
-        val row = new GenericMutableRow(struct.fields.length)
-        struct.fields.zipWithIndex.foreach {
-          case (field, i) =>
-            val value = convert(c.get(field.name), field.dataType)
-            row.update(i, value)
-        }
-        row
-      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
-        val converted = c.map {
-          case (key, value) =>
-            (convert(key, keyType), convert(value, valueType))
-        }
-        JMapWrapper(converted)
       case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
-        val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType))
-        converted: Seq[Any]
+        c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any]
+      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
+        case (key, value) => (convert(key, keyType), convert(value, valueType))
+      }.toMap
+      case (c, StructType(fields)) if c.getClass.isArray =>
+        new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map {
+          case (e, f) => convert(e, f.dataType)
+        }): Row
-      case (c: java.util.Calendar, TimestampType) =>
-        new java.sql.Timestamp(c.getTime().getTime())
+      case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime())
+      case (c: Int, ByteType) => c.toByte
+      case (c: Int, ShortType) => c.toShort
+      case (c: Double, FloatType) => c.toFloat
+      case (c, StringType) if !c.isInstanceOf[String] => c.toString
       case (c, _) => c
     }

     val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
-      rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) })
+      rdd.map(m => m.zip(schema.fields).map {
+        case (value, field) => convert(value, field.dataType)
+      })
     } else {
       rdd
     }

     val rowRdd = convertedRdd.mapPartitions { iter =>
-      val row = new GenericMutableRow(schema.fields.length)
-      val fieldsWithIndex = schema.fields.zipWithIndex
-      iter.map { m =>
-        // We cannot use m.values because the order of values returned by m.values may not
-        // match fields order.
-        fieldsWithIndex.foreach {
-          case (field, i) =>
-            val value =
-              m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull
-            row.update(i, value)
-        }
-        row: Row
-      }
+      iter.map { m => new GenericRow(m): Row}
     }

     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))(self))

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

@@ -383,7 +383,7 @@ class SchemaRDD(
     import scala.collection.Map

     def toJava(obj: Any, dataType: DataType): Any = dataType match {
-      case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
+      case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct)
       case array: ArrayType => obj match {
         case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
         case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
@@ -397,21 +397,19 @@ class SchemaRDD(
       // Pyrolite can handle Timestamp
       case other => obj
     }

-    def rowToMap(row: Row, structType: StructType): JMap[String, Any] = {
-      val fields = structType.fields.map(field => (field.name, field.dataType))
-      val map: JMap[String, Any] = new java.util.HashMap
-      row.zip(fields).foreach {
-        case (obj, (attrName, dataType)) => map.put(attrName, toJava(obj, dataType))
-      }
-      map
+    def rowToArray(row: Row, structType: StructType): Array[Any] = {
+      val fields = structType.fields.map(field => field.dataType)
+      row.zip(fields).map {
+        case (obj, dataType) => toJava(obj, dataType)
+      }.toArray
     }

     val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output)
     this.mapPartitions { iter =>
       val pickle = new Pickler
       iter.map { row =>
-        rowToMap(row, rowSchema)
-      }.grouped(10).map(batched => pickle.dumps(batched.toArray))
+        rowToArray(row, rowSchema)
+      }.grouped(100).map(batched => pickle.dumps(batched.toArray))
     }
   }
 }