[SPARK-33268][SQL][PYTHON] Fix bugs for casting data from/to PythonUserDefinedType

### What changes were proposed in this pull request?

This PR intends to fix bugs for casting data from/to PythonUserDefinedType. A sequence of queries to reproduce this issue is as follows;
```
>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import col
>>> from pyspark.sql.types import *
>>> from pyspark.testing.sqlutils import *
>>>
>>> row = Row(point=ExamplePoint(1.0, 2.0))
>>> df = spark.createDataFrame([row])
>>> df.select(col("point").cast(PythonOnlyUDT()))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/maropu/Repositories/spark/spark-master/python/pyspark/sql/dataframe.py", line 1402, in select
    jdf = self._jdf.select(self._jcols(*cols))
  File "/Users/maropu/Repositories/spark/spark-master/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py", line 1305, in __call__
  File "/Users/maropu/Repositories/spark/spark-master/python/pyspark/sql/utils.py", line 111, in deco
    return f(*a, **kw)
  File "/Users/maropu/Repositories/spark/spark-master/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o44.select.
: java.lang.NullPointerException
	at org.apache.spark.sql.types.UserDefinedType.acceptsType(UserDefinedType.scala:84)
	at org.apache.spark.sql.catalyst.expressions.Cast$.canCast(Cast.scala:96)
	at org.apache.spark.sql.catalyst.expressions.CastBase.checkInputDataTypes(Cast.scala:267)
	at org.apache.spark.sql.catalyst.expressions.CastBase.resolved$lzycompute(Cast.scala:290)
	at org.apache.spark.sql.catalyst.expressions.CastBase.resolved(Cast.scala:290)
```
A root cause of this issue is that, since `PythonUserDefinedType#userClass` is always null, `isAssignableFrom` in `UserDefinedType#acceptsType` throws a NullPointerException. To fix it, this PR defines `acceptsType` in `PythonUserDefinedType` and filters out the null case in `UserDefinedType#acceptsType`.

### Why are the changes needed?

Bug fixes.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added tests.

Closes #30169 from maropu/FixPythonUDTCast.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Takeshi Yamamuro 2020-10-28 08:33:02 -07:00 committed by Dongjoon Hyun
parent b26ae98407
commit a6216e2446
2 changed files with 16 additions and 2 deletions

View file

@@ -27,6 +27,7 @@ import unittest
from pyspark.sql import Row
from pyspark.sql.functions import col
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.utils import AnalysisException
from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \
TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \
DecimalType, BinaryType, BooleanType, NullType
@@ -441,6 +442,14 @@ class TypesTests(ReusedSQLTestCase):
result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
def test_cast_to_udt_with_udt(self):
from pyspark.sql.functions import col
row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
self.assertRaises(AnalysisException, lambda: df.select(col("point").cast(PythonOnlyUDT())))
self.assertRaises(AnalysisException,
lambda: df.select(col("python_only_point").cast(ExamplePointUDT())))
def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),

View file

@@ -78,8 +78,8 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
*/
override private[spark] def asNullable: UserDefinedType[UserType] = this
override private[sql] def acceptsType(dataType: DataType) = dataType match {
case other: UserDefinedType[_] =>
override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match {
case other: UserDefinedType[_] if this.userClass != null && other.userClass != null =>
this.getClass == other.getClass ||
this.userClass.isAssignableFrom(other.userClass)
case _ => false
@@ -131,6 +131,11 @@ private[sql] class PythonUserDefinedType(
("sqlType" -> sqlType.jsonValue)
}
override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match {
case other: PythonUserDefinedType => pyUDT == other.pyUDT
case _ => false
}
override def equals(other: Any): Boolean = other match {
case that: PythonUserDefinedType => pyUDT == that.pyUDT
case _ => false