[SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times

~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~

The fix above didn't work, so I added a workaround: if a Python UDF is applied to a Python UDT, this will pass the Python SQL types through as inputs. Still incorrect, but at least it doesn't throw exceptions on the Scala side (a sketch of the resulting pass-through pattern follows the diff below). davies harsha2010

Author: Xiangrui Meng <meng@databricks.com>

Closes #6442 from mengxr/SPARK-7903 and squashes the following commits:

c257d2a [Xiangrui Meng] add a workaround for VectorUDT

(cherry picked from commit 530efe3e80)
Signed-off-by: Xiangrui Meng <meng@databricks.com>

@@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   }
 
   override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(4)
     obj match {
       case SparseVector(size, indices, values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, indices.toSeq)
         row.update(3, values.toSeq)
+        row
       case DenseVector(values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
         row.update(3, values.toSeq)
+        row
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case row: Row =>
+        row
     }
-    row
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      // TODO: something wrong with UDT serialization
-      case v: Vector =>
-        v
       case row: Row =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
@@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case v: Vector =>
+        v
     }
   }
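
For readers skimming the diff: the new "case row: Row =>" and "case v: Vector =>" arms simply pass already-converted values through, making serialize and deserialize tolerant of being called a second time. Below is a minimal, self-contained sketch of that pass-through pattern in plain Scala; ToyVector, ToyRow, and ToyVectorUDT are hypothetical stand-ins for MLlib's Vector, Spark SQL's Row, and VectorUDT, not the actual Spark classes (the real code uses GenericMutableRow as in the diff).

// A minimal sketch of the pass-through workaround above (plain Scala,
// no Spark dependency).
sealed trait Datum
case class ToyVector(values: Seq[Double]) extends Datum
case class ToyRow(fields: Seq[Any]) extends Datum

object ToyVectorUDT {
  // serialize tolerates being handed an already-serialized row; without
  // this case, a second call would raise scala.MatchError.
  def serialize(obj: Datum): ToyRow = obj match {
    case ToyVector(values) =>
      ToyRow(Seq(1.toByte, null, null, values)) // dense encoding, as in the diff
    case row: ToyRow => // workaround: already serialized, pass it through
      row
  }

  // deserialize likewise tolerates an already-deserialized vector.
  def deserialize(datum: Datum): ToyVector = datum match {
    case row: ToyRow =>
      ToyVector(row.fields(3).asInstanceOf[Seq[Double]])
    case v: ToyVector => // workaround: already deserialized, pass it through
      v
  }
}

object PassThroughDemo extends App {
  val v = ToyVector(Seq(1.0, 2.0))
  val once = ToyVectorUDT.serialize(v)
  val twice = ToyVectorUDT.serialize(once) // second call is now a no-op
  assert(once == twice)
  assert(ToyVectorUDT.deserialize(ToyVectorUDT.deserialize(once)) == v)
}

As the TODO comments note, this only masks the double-conversion problem tracked in SPARK-7186; the real fix requires a clean separation between internal SQL types and language-specific types, including UDTs.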