[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns.

## What changes were proposed in this pull request?

This PR adds equality operators to UDT classes so that they can be correctly tested for dataType equality during union operations.

Previously, calling unionAll on two DataFrames with UDT columns, as in the example below, failed with `AnalysisException: u"unresolved operator 'Union;"`.

```
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql import types

schema = types.StructType([types.StructField("point", PythonOnlyUDT(), True)])

a = sqlCtx.createDataFrame([[PythonOnlyPoint(1.0, 2.0)]], schema)
b = sqlCtx.createDataFrame([[PythonOnlyPoint(3.0, 4.0)]], schema)

c = a.unionAll(b)
```

## How was this patch tested?

Tested using two new unit tests: one in the Python SQL tests (`pyspark/sql/tests.py`) and one in `DataFrameSuite`.

Additional information: https://issues.apache.org/jira/browse/SPARK-13410

Author: Franklyn D'souza <franklynd@gmail.com>

Closes #11279 from damnMeddlingKid/udt-union-all.
This commit is contained in:
Franklyn D'souza 2016-02-21 16:58:17 -08:00 committed by Reynold Xin
parent 0cbadf28c9
commit 0f90f4e6ac
4 changed files with 50 additions and 1 deletions

View file

@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase):
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
def test_unionAll_with_udt(self):
    """unionAll of two DataFrames that share a UDT column returns the rows of both."""
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    # One schema reused by both sides: a non-nullable double label and a
    # non-nullable ExamplePointUDT column.
    fields = [StructField("label", DoubleType(), False),
              StructField("point", ExamplePointUDT(), False)]
    schema = StructType(fields)
    first = self.sqlCtx.createDataFrame([(1.0, ExamplePoint(1.0, 2.0))], schema)
    second = self.sqlCtx.createDataFrame([(2.0, ExamplePoint(3.0, 4.0))], schema)
    # Order by label so the comparison does not depend on the union's row order.
    rows = first.unionAll(second).orderBy("label").collect()
    expected = [Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
                Row(label=2.0, point=ExamplePoint(3.0, 4.0))]
    self.assertEqual(rows, expected)
def test_column_operators(self):
ci = self.df.key
cs = self.df.value

View file

@ -86,6 +86,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
this.getClass == dataType.getClass
/** SQL text for this UDT, delegated to the underlying `sqlType`. */
override def sql: String = sqlType.sql
/**
 * A UDT equals another UDT when this type accepts it, as decided by [[acceptsType]];
 * any non-UDT value is never equal.
 */
override def equals(other: Any): Boolean = other match {
  case that: UserDefinedType[_] => this.acceptsType(that)
  case _ => false
}

/**
 * Hashes on the runtime class so that UDT instances which compare equal also share a hash
 * code, as hash-based collections require.
 *
 * NOTE(review): this assumes `acceptsType` only accepts instances of the same runtime class;
 * a subclass that overrides `acceptsType` more loosely should override `hashCode` as well.
 */
override def hashCode(): Int = getClass.hashCode()
}
/**
@ -112,4 +117,9 @@ private[sql] class PythonUserDefinedType(
("serializedClass" -> serializedPyClass) ~
("sqlType" -> sqlType.jsonValue)
}
/**
 * Two Python UDTs are equal when their `pyUDT` values are equal. Uses `==` rather than
 * `.equals` so that a null `pyUDT` compares instead of throwing.
 */
override def equals(other: Any): Boolean = other match {
  case that: PythonUserDefinedType => this.pyUDT == that.pyUDT
  case _ => false
}

/** Hash derived from `pyUDT` only, matching the single field `equals` compares. */
override def hashCode(): Int = java.util.Objects.hashCode(pyUDT)
}

View file

@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
* @param y y coordinate
*/
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {

  /** Two points are equal when both coordinates compare equal with `==`. */
  override def equals(other: Any): Boolean = other match {
    case that: ExamplePoint => this.x == that.x && this.y == that.y
    case _ => false
  }

  /**
   * Hash built from both coordinates so that equal points hash alike.
   * `##` is used instead of `hashCode` because it agrees with `==` on doubles
   * (for example 0.0 and -0.0 compare equal and get the same `##`).
   */
  override def hashCode(): Int = 31 * x.## + y.##
}
/**
* User-defined type for [[ExamplePoint]].

View file

@ -112,6 +112,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
)
}
test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
  // Built fresh on every call so that each DataFrame carries its own ExamplePointUDT
  // instance; the union must treat the two distinct instances as the same data type.
  def labelledPointSchema(): StructType = StructType(Array(
    StructField("label", IntegerType, false),
    StructField("point", new ExamplePointUDT(), false)))

  // Single-row DataFrame holding one labelled point.
  def singleRowFrame(label: Int, x: Double, y: Double) = {
    val rows = sparkContext.parallelize(Seq(Row(label, new ExamplePoint(x, y))))
    sqlContext.createDataFrame(rows, labelledPointSchema())
  }

  val left = singleRowFrame(1, 1.0, 2.0)
  val right = singleRowFrame(2, 3.0, 4.0)
  val expected = Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
  checkAnswer(left.unionAll(right).orderBy("label"), expected)
}
test("empty data frame") {
assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
assert(sqlContext.emptyDataFrame.count() === 0)