[SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times
~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~
The fix above didn't work, so I added a workaround instead: if a Python UDF is applied to a Python UDT, the Python SQL types are passed as inputs. Still incorrect, but at least it doesn't throw exceptions on the Scala side. davies harsha2010
Author: Xiangrui Meng <meng@databricks.com>
Closes #6442 from mengxr/SPARK-7903 and squashes the following commits:
c257d2a [Xiangrui Meng] add a workaround for VectorUDT
(cherry picked from commit 530efe3e80)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent ab62d73ddb
commit 7b5dffb802
@@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   }
 
   override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(4)
     obj match {
       case SparseVector(size, indices, values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, indices.toSeq)
         row.update(3, values.toSeq)
+        row
       case DenseVector(values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
         row.update(3, values.toSeq)
+        row
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case row: Row =>
+        row
     }
-    row
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      // TODO: something wrong with UDT serialization
-      case v: Vector =>
-        v
       case row: Row =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
@@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case v: Vector =>
+        v
     }
   }
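The shape of the workaround can be illustrated with a small, self-contained sketch (plain Python, no Spark dependency; the `Vector` class and the 4-slot row layout here are simplified stand-ins for Spark's `Vector` and `Row`): each direction of conversion tolerates a value that is already in the target representation and passes it through, so an accidental second call (the SPARK-7186 symptom) no longer raises.

```python
# Sketch of the pass-through workaround: serialize()/deserialize() become
# idempotent by accepting already-converted inputs. Not Spark code; `Vector`
# and the 4-field tuple "row" are hypothetical stand-ins.

class Vector:
    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        return isinstance(other, Vector) and self.values == other.values


def serialize(obj):
    if isinstance(obj, tuple):        # workaround: already a "row", pass through
        return obj
    # dense layout: (type tag, None, None, values), mirroring the 4-slot row
    return (1, None, None, obj.values)


def deserialize(datum):
    if isinstance(datum, Vector):     # workaround: already a Vector, pass through
        return datum
    assert len(datum) == 4, f"expected row of length 4, got {len(datum)}"
    return Vector(datum[3])


if __name__ == "__main__":
    v = Vector([1.0, 2.0, 3.0])
    row = serialize(v)
    assert serialize(row) == row                 # double serialize is a no-op
    assert deserialize(deserialize(row)) == v    # double deserialize too
```

As the commit message notes, this only avoids the exception; the real fix (a clean separation between internal SQL types and language-specific types) is tracked in SPARK-7186.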