From 0182817ea3ac11742f158a5d04cc8322fd992d14 Mon Sep 17 00:00:00 2001
From: uncleGen
Date: Mon, 28 Oct 2019 20:50:34 +0900
Subject: [PATCH] [SPARK-28158][SQL] Hive UDFs support UDT types

## What changes were proposed in this pull request?

After this PR, we can create and register Hive UDFs that accept UDT-typed arguments, such as `VectorUDT` and `MatrixUDT`. These UDTs are widely used in Spark machine learning.

## How was this patch tested?

Added new unit tests: a wrap/unwrap round trip for a UDT in `HiveInspectorSuite`, and a new `HiveUserDefinedTypeSuite` that registers a Hive UDF and runs it over a UDT column.
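As a usage sketch (not part of the patch): the snippet below assumes a Hive-enabled `spark` session and reuses the `TestUDF`/`get_point_x` function added in the new test suite; the `points` view name is made up for illustration.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
import org.apache.spark.sql.types.StructType

// A one-row DataFrame whose "point" column has a user-defined type.
val schema = new StructType().add("point", new ExamplePointUDT)
val df = spark.createDataFrame(
  java.util.Arrays.asList(Row(new ExamplePoint(1.0, 2.0))), schema)
df.createOrReplaceTempView("points")

// Register the Hive GenericUDF from this patch's test suite and call it on
// the UDT column. Hive sees the UDT as its underlying sqlType (array<double>
// for ExamplePointUDT), which is what the new toInspector() case enables;
// before this patch the UDT fell through the match and raised a MatchError.
spark.sql("CREATE FUNCTION get_point_x AS 'org.apache.spark.sql.hive.TestUDF'")
spark.sql("SELECT get_point_x(point) FROM points").show()  // 1.0, the x coordinate
```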
Closes #24961 from uncleGen/SPARK-28158.

Authored-by: uncleGen
Signed-off-by: HyukjinKwon
---
 .../spark/sql/hive/HiveInspectors.scala       |  5 ++
 .../spark/sql/hive/HiveInspectorSuite.scala   |  8 +-
 .../sql/hive/HiveUserDefinedTypeSuite.scala   | 73 +++++++++++++++++++
 3 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 33b5bcefd8..5b627b8164 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -787,6 +787,9 @@ private[hive] trait HiveInspectors {
       ObjectInspectorFactory.getStandardStructObjectInspector(
         java.util.Arrays.asList(fields.map(f => f.name) : _*),
         java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*))
+    case _: UserDefinedType[_] =>
+      val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType
+      toInspector(sqlType)
   }
 
   /**
@@ -849,6 +852,8 @@ private[hive] trait HiveInspectors {
       }
     case Literal(_, dt: StructType) =>
       toInspector(dt)
+    case Literal(_, dt: UserDefinedType[_]) =>
+      toInspector(dt.sqlType)
     // We will enumerate all of the possible constant expressions, throw exception if we missed
     case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
     // ideally, we don't test the foldable here(but in optimizer), however, some of the
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index c300660458..5912992694 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.io.LongWritable
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Row, TestUserClassUDT}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
@@ -214,6 +214,12 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
     })
   }
 
+  test("wrap / unwrap UDT Type") {
+    val dt = new TestUserClassUDT
+    checkValue(1, unwrap(wrap(1, toInspector(dt), dt), toInspector(dt)))
+    checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
+  }
+
   test("wrap / unwrap Struct Type") {
     val dt = StructType(dataTypes.zipWithIndex.map {
       case (t, idx) => StructField(s"c_$idx", t)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala
new file mode 100644
index 0000000000..bddb7688fe
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+
+import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
+import org.apache.spark.sql.types.StructType
+
+class HiveUserDefinedTypeSuite extends QueryTest with TestHiveSingleton {
+  private val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName
+
+  test("Support UDT in Hive UDF") {
+    val functionName = "get_point_x"
+    try {
+      val schema = new StructType().add("point", new ExamplePointUDT)
+      val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
+      val input = inputGenerator.apply().asInstanceOf[Row]
+      val df = spark.createDataFrame(Array(input).toList.asJava, schema)
+      df.createOrReplaceTempView("src")
+      spark.sql(s"CREATE FUNCTION $functionName AS '$functionClass'")
+
+      checkAnswer(
+        spark.sql(s"SELECT $functionName(point) FROM src"),
+        Row(input.getAs[ExamplePoint](0).x))
+    } finally {
+      // If the test failed part way, we don't want to mask the failure by failing to remove
+      // temp tables that never got created.
+      spark.sql(s"DROP FUNCTION IF EXISTS $functionName")
+      assert(
+        !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+        s"Function $functionName should have been dropped. But, it still exists.")
+    }
+  }
+}
+
+class TestUDF extends GenericUDF {
+  private var data: StandardListObjectInspector = _
+
+  override def getDisplayString(children: Array[String]): String = "get_point_x"
+
+  override def initialize(arguments: Array[ObjectInspector]): ObjectInspector = {
+    data = arguments(0).asInstanceOf[StandardListObjectInspector]
+    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+  }
+
+  override def evaluate(arguments: Array[GenericUDF.DeferredObject]): AnyRef = {
+    val point = data.getList(arguments(0).get())
+    new java.lang.Double(point.get(0).asInstanceOf[Double])
+  }
+}