[SPARK-33268][SQL][PYTHON] Fix bugs for casting data from/to PythonUserDefinedType
### What changes were proposed in this pull request? This PR intends to fix bugs for casting data from/to PythonUserDefinedType. A sequence of queries to reproduce this issue is as follows; ``` >>> from pyspark.sql import Row >>> from pyspark.sql.functions import col >>> from pyspark.sql.types import * >>> from pyspark.testing.sqlutils import * >>> >>> row = Row(point=ExamplePoint(1.0, 2.0)) >>> df = spark.createDataFrame([row]) >>> df.select(col("point").cast(PythonOnlyUDT())) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/maropu/Repositories/spark/spark-master/python/pyspark/sql/dataframe.py", line 1402, in select jdf = self._jdf.select(self._jcols(*cols)) File "/Users/maropu/Repositories/spark/spark-master/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py", line 1305, in __call__ File "/Users/maropu/Repositories/spark/spark-master/python/pyspark/sql/utils.py", line 111, in deco return f(*a, **kw) File "/Users/maropu/Repositories/spark/spark-master/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 328, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o44.select. : java.lang.NullPointerException at org.apache.spark.sql.types.UserDefinedType.acceptsType(UserDefinedType.scala:84) at org.apache.spark.sql.catalyst.expressions.Cast$.canCast(Cast.scala:96) at org.apache.spark.sql.catalyst.expressions.CastBase.checkInputDataTypes(Cast.scala:267) at org.apache.spark.sql.catalyst.expressions.CastBase.resolved$lzycompute(Cast.scala:290) at org.apache.spark.sql.catalyst.expressions.CastBase.resolved(Cast.scala:290) ``` A root cause of this issue is that, since `PythonUserDefinedType#userClass` is always null, `isAssignableFrom` in `UserDefinedType#acceptsType` throws a null pointer exception. To fix it, this PR defines `acceptsType` in `PythonUserDefinedType` and filters out the null case in `UserDefinedType#acceptsType`. ### Why are the changes needed? Bug fixes.
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests. Closes #30169 from maropu/FixPythonUDTCast. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
b26ae98407
commit
a6216e2446
|
@ -27,6 +27,7 @@ import unittest
|
|||
from pyspark.sql import Row
|
||||
from pyspark.sql.functions import col
|
||||
from pyspark.sql.udf import UserDefinedFunction
|
||||
from pyspark.sql.utils import AnalysisException
|
||||
from pyspark.sql.types import ByteType, ShortType, IntegerType, FloatType, DateType, \
|
||||
TimestampType, MapType, StringType, StructType, StructField, ArrayType, DoubleType, LongType, \
|
||||
DecimalType, BinaryType, BooleanType, NullType
|
||||
|
@ -441,6 +442,14 @@ class TypesTests(ReusedSQLTestCase):
|
|||
result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
|
||||
self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
|
||||
|
||||
def test_cast_to_udt_with_udt(self):
    """Casting between a JVM-backed UDT and a Python-only UDT must fail analysis.

    Regression test for SPARK-33268: casting a column to/from a
    ``PythonUserDefinedType`` used to raise a ``NullPointerException`` on the
    JVM side (its ``userClass`` is always null); it must instead be rejected
    with an ``AnalysisException``.
    """
    # NOTE: `col` is imported at module level by this same change (see the
    # import hunk above), so the previous method-local import was redundant.
    row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
    df = self.spark.createDataFrame([row])
    # Scala-backed UDT -> Python-only UDT: rejected at analysis time.
    self.assertRaises(AnalysisException, lambda: df.select(col("point").cast(PythonOnlyUDT())))
    # Python-only UDT -> Scala-backed UDT: also rejected at analysis time.
    self.assertRaises(AnalysisException,
                      lambda: df.select(col("python_only_point").cast(ExamplePointUDT())))
|
||||
|
||||
def test_struct_type(self):
|
||||
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
|
||||
struct2 = StructType([StructField("f1", StringType(), True),
|
||||
|
|
|
@ -78,8 +78,8 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
|
|||
*/
|
||||
override private[spark] def asNullable: UserDefinedType[UserType] = this
|
||||
|
||||
// A UDT accepts another UDT when both expose a JVM user class and either the
// UDT implementations are the same class or this UDT's user class is a
// supertype of the other's. The null guards matter: PythonUserDefinedType has
// no JVM user class (userClass is null), and calling isAssignableFrom on it
// used to throw a NullPointerException (SPARK-33268).
override private[sql] def acceptsType(dataType: DataType): Boolean = dataType match {
  case udt: UserDefinedType[_] if this.userClass != null && udt.userClass != null =>
    val sameImplementation = this.getClass == udt.getClass
    sameImplementation || this.userClass.isAssignableFrom(udt.userClass)
  case _ => false
}
|
@ -131,6 +131,11 @@ private[sql] class PythonUserDefinedType(
|
|||
("sqlType" -> sqlType.jsonValue)
|
||||
}
|
||||
|
||||
// A Python-only UDT has no JVM user class, so compatibility cannot be decided
// via userClass/isAssignableFrom; two Python UDTs are compatible exactly when
// they wrap the same Python class (same pickled pyUDT identifier).
override private[sql] def acceptsType(dataType: DataType): Boolean =
  dataType match {
    case that: PythonUserDefinedType => this.pyUDT == that.pyUDT
    case _ => false
  }
|
||||
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case that: PythonUserDefinedType => pyUDT == that.pyUDT
|
||||
case _ => false
|
||||
|
|
Loading…
Reference in a new issue