# [SPARK-11011][SQL] Narrow type of UDT serialization
## What changes were proposed in this pull request?

Narrow down the parameter type of `UserDefinedType#serialize()`. Currently the parameter type is `Any`; it makes more sense to narrow it to the actual user-defined type, which gives implementations compile-time type safety and removes the boilerplate pattern match on `Any` (see the sketch below).

## How was this patch tested?

Existing tests were successfully run on a local machine.

Author: Jakob Odersky <jakob@odersky.com>

Closes #11379 from jodersky/SPARK-11011-udt-types.
Parent: 77ba3021c1
Commit: d4d84936fb
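To make the effect concrete, here is a minimal before/after sketch of a UDT under the old and new signatures. The `Point`/`PointUDT` names are illustrative only, not part of this commit:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

case class Point(x: Double, y: Double)

class PointUDT extends UserDefinedType[Point] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Before this commit the signature was `serialize(obj: Any)`, so every
  // implementation had to pattern match just to recover the static type:
  //   override def serialize(obj: Any): GenericArrayData = obj match {
  //     case p: Point => new GenericArrayData(Array[Any](p.x, p.y))
  //   }

  // After this commit the compiler guarantees the argument type and the
  // match disappears:
  override def serialize(p: Point): GenericArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  override def deserialize(datum: Any): Point = datum match {
    case a: ArrayData => Point(a.getDouble(0), a.getDouble(1))
  }

  override def userClass: Class[Point] = classOf[Point]
}
```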
```diff
@@ -177,7 +177,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
       ))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Matrix): InternalRow = {
     val row = new GenericMutableRow(7)
     obj match {
       case sm: SparseMatrix =>
```
```diff
@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
       StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
   }
 
-  override def serialize(obj: Any): InternalRow = {
+  override def serialize(obj: Vector): InternalRow = {
     obj match {
       case SparseVector(size, indices, values) =>
         val row = new GenericMutableRow(4)
```
```diff
@@ -292,6 +292,8 @@ object MimaExcludes {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
     ) ++ Seq(
+      //SPARK-11011 UserDefinedType serialization should be strongly typed
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
       // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
```
```diff
@@ -136,16 +136,16 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
   }
 
-  private case class UDTConverter(
-      udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+  private case class UDTConverter[A >: Null](
+      udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
     // toCatalyst (it calls toCatalystImpl) will do null check.
-    override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+    override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)
 
-    override def toScala(catalystValue: Any): Any = {
+    override def toScala(catalystValue: Any): A = {
       if (catalystValue == null) null else udt.deserialize(catalystValue)
     }
 
-    override def toScalaImpl(row: InternalRow, column: Int): Any =
+    override def toScalaImpl(row: InternalRow, column: Int): A =
       toScala(row.get(column, udt.sqlType))
   }
 
```
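A note on the new `A >: Null` lower bound in `UDTConverter`: `toScala` returns `null` for a null Catalyst value, and `null` is a valid value of `A` only when `A` is a reference type. A self-contained sketch in plain Scala (illustrative names, not from this commit) of why the bound is needed:

```scala
object NullBoundDemo {
  // Without the `>: Null` bound, `if (datum == null) null else ...` would not
  // typecheck: the literal `null` has type Null, which is not a subtype of an
  // unconstrained type parameter `A`.
  def decode[A >: Null](datum: Any)(convert: Any => A): A =
    if (datum == null) null else convert(datum)

  def main(args: Array[String]): Unit = {
    println(decode[String](null)(_.toString))    // prints: null
    println(decode[String]("spark")(_.toString)) // prints: spark
  }
}
```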
```diff
@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
  * The conversion via `deserialize` occurs when reading from a `DataFrame`.
  */
 @DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
+abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT */
   def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   /**
    * Convert the user type to a SQL datum
-   *
-   * TODO: Can we make this take obj: UserType? The issue is in
-   * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
    */
-  def serialize(obj: Any): Any
+  def serialize(obj: UserType): Any
 
   /** Convert a SQL datum to the user type */
   def deserialize(datum: Any): UserType
```
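One consequence of the `UserType >: Null` bound worth flagging for UDT authors (an illustrative sketch, not from this commit): value types such as `Int` can no longer be used directly as `UserType`, because `Null` is not a subtype of `AnyVal`; a reference wrapper satisfies the bound.

```scala
import org.apache.spark.sql.types._

// `class IntUDT extends UserDefinedType[Int]` now fails to compile:
// Int does not conform to the lower bound `>: Null`.

// A boxed reference type satisfies the bound:
case class BoxedInt(value: Int)

class BoxedIntUDT extends UserDefinedType[BoxedInt] {
  override def sqlType: DataType = IntegerType

  // Narrowing the return type to Int is fine; the abstract signature is
  // `def serialize(obj: UserType): Any`.
  override def serialize(obj: BoxedInt): Int = obj.value

  override def deserialize(datum: Any): BoxedInt = datum match {
    case i: Int => BoxedInt(i)
  }

  override def userClass: Class[BoxedInt] = classOf[BoxedInt]
}
```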
```diff
@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
 
   override def sqlType: DataType = IntegerType
 
-  override def serialize(obj: Any): Int = {
-    obj match {
-      case groupableData: GroupableData => groupableData.data
-    }
-  }
+  override def serialize(groupableData: GroupableData): Int = groupableData.data
 
   override def deserialize(datum: Any): GroupableData = {
     datum match {
```
```diff
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
 
   override def sqlType: DataType = MapType(IntegerType, IntegerType)
 
-  override def serialize(obj: Any): MapData = {
-    obj match {
-      case groupableData: UngroupableData =>
-        val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
-        val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
-        new ArrayBasedMapData(keyArray, valueArray)
-    }
+  override def serialize(ungroupableData: UngroupableData): MapData = {
+    val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+    val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+    new ArrayBasedMapData(keyArray, valueArray)
   }
 
   override def deserialize(datum: Any): UngroupableData = {
```
```diff
@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
```
```diff
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
 
   override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
 
-  override def serialize(obj: Any): GenericArrayData = {
-    obj match {
-      case p: ExamplePoint =>
-        val output = new Array[Any](2)
-        output(0) = p.x
-        output(1) = p.y
-        new GenericArrayData(output)
-    }
+  override def serialize(p: ExamplePoint): GenericArrayData = {
+    val output = new Array[Any](2)
+    output(0) = p.x
+    output(1) = p.y
+    new GenericArrayData(output)
   }
 
   override def deserialize(datum: Any): ExamplePoint = {
```
```diff
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
 
   override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
 
-  override def serialize(obj: Any): ArrayData = {
-    obj match {
-      case features: MyDenseVector =>
-        new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-    }
+  override def serialize(features: MyDenseVector): ArrayData = {
+    new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
   }
 
   override def deserialize(datum: Any): MyDenseVector = {
```
```diff
@@ -590,14 +590,11 @@ object TestingUDT {
       .add("b", LongType, nullable = false)
       .add("c", DoubleType, nullable = false)
 
-    override def serialize(obj: Any): Any = {
+    override def serialize(n: NestedStruct): Any = {
       val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
-      obj match {
-        case n: NestedStruct =>
-          row.setInt(0, n.a)
-          row.setLong(1, n.b)
-          row.setDouble(2, n.c)
-      }
+      row.setInt(0, n.a)
+      row.setLong(1, n.b)
+      row.setDouble(2, n.c)
     }
 
     override def userClass: Class[NestedStruct] = classOf[NestedStruct]
```