[SPARK-32090][SQL] Improve UserDefinedType.equal() to make it be symmetrical

### What changes were proposed in this pull request?

This PR fixes `UserDefinedType.equals()` by comparing the UDT classes instead of delegating to `acceptsType()`.

### Why are the changes needed?

An equality comparison between two UDTs can return a different result depending on the order of the operands, which breaks the symmetry expected of `equals()`:

```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```
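
The asymmetry can be reproduced without Spark. The following is a minimal sketch, not Spark's actual code: `SketchUDT`, `BaseUDT`, `SubUDT`, `oldEquals` and `newEquals` are hypothetical names, and the body of `acceptsType` is an assumption about how a subtype-accepting check might look. It only illustrates why a check based on "accepts" is directional while a class comparison is not.

```scala
// Hypothetical stand-in for a UDT; this is NOT Spark's UserDefinedType.
abstract class SketchUDT {
  def userClass: Class[_]

  // Assumed behaviour: accept the same UDT class, or any UDT whose
  // user class is a subtype of this UDT's user class.
  def acceptsType(that: SketchUDT): Boolean =
    this.getClass == that.getClass ||
      this.userClass.isAssignableFrom(that.userClass)

  // Equality as it was before this change: delegates to acceptsType.
  def oldEquals(that: SketchUDT): Boolean = this.acceptsType(that)

  // Equality as it is after this change: compares the UDT classes only.
  def newEquals(that: SketchUDT): Boolean = this.getClass == that.getClass
}

// Sub is a subclass of Base, mirroring the userClass relationship above.
class Base
class Sub extends Base

class BaseUDT extends SketchUDT { def userClass: Class[_] = classOf[Base] }
class SubUDT extends SketchUDT { def userClass: Class[_] = classOf[Sub] }

object SymmetryDemo {
  def main(args: Array[String]): Unit = {
    val base = new BaseUDT
    val sub = new SubUDT
    println(base.oldEquals(sub)) // true: Base is assignable from Sub
    println(sub.oldEquals(base)) // false: Sub is not assignable from Base
    println(base.newEquals(sub)) // false
    println(sub.newEquals(base)) // false
  }
}
```

The subtype check is intentionally one-directional, so it stays useful as `acceptsType` but is a poor fit for `equals`.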

### Does this PR introduce _any_ user-facing change?

Yes.

Before:
```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // true
println(udt2 == udt1) // false
```

After:
```scala
// ExampleSubTypeUDT.userClass is a subclass of ExampleBaseTypeUDT.userClass
val udt1 = new ExampleBaseTypeUDT
val udt2 = new ExampleSubTypeUDT
println(udt1 == udt2) // false
println(udt2 == udt1) // false
```
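
Comparing classes also keeps `equals` in line with a `hashCode` derived from `getClass`. Under the old rule, two UDTs of different classes could compare equal even though their hash codes come from different classes and are not guaranteed to match. The sketch below uses hypothetical classes (`SketchType`, `BaseSketchType`, `SubSketchType`) rather than Spark's own, and only shows that equal instances hash alike and deduplicate in a `Set`.

```scala
// Hypothetical type with class-based equality and a class-based hash.
// This is NOT Spark's UserDefinedType.
abstract class SketchType {
  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: Any): Boolean = other match {
    case that: SketchType => this.getClass == that.getClass
    case _ => false
  }
}

class BaseSketchType extends SketchType
class SubSketchType extends SketchType

object HashConsistencyDemo {
  def main(args: Array[String]): Unit = {
    val a = new SubSketchType
    val b = new SubSketchType
    val c = new BaseSketchType

    println(a == b)                        // true: same class
    println(a.hashCode == b.hashCode)      // true: same Class object
    println(a == c)                        // false: different classes
    println(Set[SketchType](a, b, c).size) // 2: a and b collapse into one entry
  }
}
```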

### How was this patch tested?

Added unit tests to `UserDefinedTypeSuite` covering both `equals` and `acceptsType`.

Closes #28923 from Ngone51/fix-udt-equal.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
yi.wu 2020-06-28 21:49:10 -07:00 committed by Dongjoon Hyun
parent f944603872
commit 0ec17c989d
2 changed files with 19 additions and 1 deletions

@@ -90,7 +90,7 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
   override def hashCode(): Int = getClass.hashCode()
   override def equals(other: Any): Boolean = other match {
-    case that: UserDefinedType[_] => this.acceptsType(that)
+    case that: UserDefinedType[_] => this.getClass == that.getClass
     case _ => false
   }

@@ -134,6 +134,24 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
       MyLabeledPoint(1.0, new TestUDT.MyDenseVector(Array(0.1, 1.0))),
       MyLabeledPoint(0.0, new TestUDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+  test("SPARK-32090: equal") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    val udt3 = new ExampleSubTypeUDT
+    assert(udt1 !== udt2)
+    assert(udt2 !== udt1)
+    assert(udt2 === udt3)
+    assert(udt3 === udt2)
+  }
+  test("SPARK-32090: acceptsType") {
+    val udt1 = new ExampleBaseTypeUDT
+    val udt2 = new ExampleSubTypeUDT
+    assert(udt1.acceptsType(udt2))
+    assert(!udt2.acceptsType(udt1))
+  }
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v }
     val labelsArrays: Array[Double] = labels.collect()