[SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical
### What changes were proposed in this pull request? This PR fix `UserDefinedType.equal()` by comparing the UDT class instead of checking `acceptsType()`. ### Why are the changes needed? It's weird that equality comparison between two UDT types can have different result by switching the order: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // true println(udt2 == udt1) // false ``` ### Does this PR introduce _any_ user-facing change? Yes. Before: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // true println(udt2 == udt1) // false ``` After: ```scala // ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass val udt1 = new ExampleBaseTypeUDT val udt2 = new ExampleSubTypeUDT println(udt1 == udt2) // false println(udt2 == udt1) // false ``` ### How was this patch tested? Added a unit test. Closes #28923 from Ngone51/fix-udt-equal. Authored-by: yi.wu <yi.wu@databricks.com> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
parent
f944603872
commit
0ec17c989d
|
@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
|
|||
override def hashCode(): Int = getClass.hashCode()
|
||||
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case that: UserDefinedType[_] => this.acceptsType(that)
|
||||
case that: UserDefinedType[_] => this.getClass == that.getClass
|
||||
case _ => false
|
||||
}
|
||||
|
||||
|
|
|
@ -134,6 +134,24 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
|
|||
MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
|
||||
MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
|
||||
|
||||
|
||||
test("SPARK-32090: equal") {
|
||||
val udt1 = new ExampleBaseTypeUDT
|
||||
val udt2 = new ExampleSubTypeUDT
|
||||
val udt3 = new ExampleSubTypeUDT
|
||||
assert(udt1 !== udt2)
|
||||
assert(udt2 !== udt1)
|
||||
assert(udt2 === udt3)
|
||||
assert(udt3 === udt2)
|
||||
}
|
||||
|
||||
test("SPARK-32090: acceptsType") {
|
||||
val udt1 = new ExampleBaseTypeUDT
|
||||
val udt2 = new ExampleSubTypeUDT
|
||||
assert(udt1.acceptsType(udt2))
|
||||
assert(!udt2.acceptsType(udt1))
|
||||
}
|
||||
|
||||
test("register user type: MyDenseVector for MyLabeledPoint") {
|
||||
val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
|
||||
val labelsArrays: Array[Double] = labels.collect()
|
||||
|
|
Loading…
Reference in a new issue