[SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times
~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~
The fix above didn't work, so I added a workaround instead: if a Python UDF is applied to a Python UDT, the Python SQL types are passed in as inputs. This is still incorrect, but at least it doesn't throw exceptions on the Scala side. davies harsha2010
Author: Xiangrui Meng <meng@databricks.com>
Closes #6442 from mengxr/SPARK-7903 and squashes the following commits:
c257d2a [Xiangrui Meng] add a workaround for VectorUDT
(cherry picked from commit 530efe3e80)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
This commit is contained in:
parent
ab62d73ddb
commit
7b5dffb802
|
@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
|
||||||
}
|
}
|
||||||
|
|
||||||
override def serialize(obj: Any): Row = {
|
override def serialize(obj: Any): Row = {
|
||||||
val row = new GenericMutableRow(4)
|
|
||||||
obj match {
|
obj match {
|
||||||
case SparseVector(size, indices, values) =>
|
case SparseVector(size, indices, values) =>
|
||||||
|
val row = new GenericMutableRow(4)
|
||||||
row.setByte(0, 0)
|
row.setByte(0, 0)
|
||||||
row.setInt(1, size)
|
row.setInt(1, size)
|
||||||
row.update(2, indices.toSeq)
|
row.update(2, indices.toSeq)
|
||||||
row.update(3, values.toSeq)
|
row.update(3, values.toSeq)
|
||||||
|
row
|
||||||
case DenseVector(values) =>
|
case DenseVector(values) =>
|
||||||
|
val row = new GenericMutableRow(4)
|
||||||
row.setByte(0, 1)
|
row.setByte(0, 1)
|
||||||
row.setNullAt(1)
|
row.setNullAt(1)
|
||||||
row.setNullAt(2)
|
row.setNullAt(2)
|
||||||
row.update(3, values.toSeq)
|
row.update(3, values.toSeq)
|
||||||
|
row
|
||||||
|
// TODO: There are bugs in UDT serialization because we don't have a clear separation between
|
||||||
|
// TODO: internal SQL types and language specific types (including UDT). UDT serialize and
|
||||||
|
// TODO: deserialize may get called twice. See SPARK-7186.
|
||||||
|
case row: Row =>
|
||||||
|
row
|
||||||
}
|
}
|
||||||
row
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def deserialize(datum: Any): Vector = {
|
override def deserialize(datum: Any): Vector = {
|
||||||
datum match {
|
datum match {
|
||||||
// TODO: something wrong with UDT serialization
|
|
||||||
case v: Vector =>
|
|
||||||
v
|
|
||||||
case row: Row =>
|
case row: Row =>
|
||||||
require(row.length == 4,
|
require(row.length == 4,
|
||||||
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
|
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
|
||||||
|
@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
|
||||||
val values = row.getAs[Iterable[Double]](3).toArray
|
val values = row.getAs[Iterable[Double]](3).toArray
|
||||||
new DenseVector(values)
|
new DenseVector(values)
|
||||||
}
|
}
|
||||||
|
// TODO: There are bugs in UDT serialization because we don't have a clear separation between
|
||||||
|
// TODO: internal SQL types and language specific types (including UDT). UDT serialize and
|
||||||
|
// TODO: deserialize may get called twice. See SPARK-7186.
|
||||||
|
case v: Vector =>
|
||||||
|
v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue