diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index b096a6db85..a08433ba79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -203,12 +203,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
 
   // Tests to make sure that all operators correctly convert types on the way out.
   test("Local UDTs") {
-    val df = Seq((1, new UDT.MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
-    df.collect()(0).getAs[UDT.MyDenseVector](1)
-    df.take(1)(0).getAs[UDT.MyDenseVector](1)
-    df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[UDT.MyDenseVector](0)
-    df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0)
-      .getAs[UDT.MyDenseVector](0)
+    val vec = new UDT.MyDenseVector(Array(0.1, 1.0))
+    val df = Seq((1, vec)).toDF("int", "vec")
+    assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1))
+    assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1))
+    checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
+    checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
   }
 
   test("UDTs with JSON") {