[SPARK-28158][SQL] Hive UDFs support UDT types

## What changes were proposed in this pull request?

After this PR, we can create and register Hive UDFs that accept UDT types, such as `VectorUDT` and `MatrixUDT`. These UDTs are widely used in Spark machine learning.
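For illustration, here is a hedged sketch (not part of this patch) of the kind of usage the change enables: registering a Hive `GenericUDF` and applying it to a `VectorUDT` column. It assumes a spark-shell session where `spark` is in scope, and `com.example.hive.VectorFirstUDF` is a hypothetical UDF class assumed to be on the classpath:

```scala
// Minimal sketch, assuming a Hive GenericUDF class com.example.hive.VectorFirstUDF
// (hypothetical) is on the classpath. The UDT column is handed to Hive through
// the object inspector of its underlying sqlType.
import org.apache.spark.ml.linalg.Vectors

val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 2.0)),
  (1, Vectors.dense(3.0, 4.0))
)).toDF("id", "features")
df.createOrReplaceTempView("points")

// Register the Hive UDF and apply it to the UDT column.
spark.sql("CREATE FUNCTION vec_first AS 'com.example.hive.VectorFirstUDF'")
spark.sql("SELECT vec_first(features) FROM points").show()
```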

## How was this patch tested?

Added new unit tests: a `wrap / unwrap` case in `HiveInspectorSuite` and a new `HiveUserDefinedTypeSuite`.

Closes #24961 from uncleGen/SPARK-28158.

Authored-by: uncleGen <hustyugm@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Commit 0182817ea3 (parent a8d5134981), authored by uncleGen and committed by HyukjinKwon on 2019-10-28 20:50:34 +09:00.
3 changed files with 85 additions and 1 deletion.

**`sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala`**

```diff
@@ -787,6 +787,9 @@ private[hive] trait HiveInspectors {
       ObjectInspectorFactory.getStandardStructObjectInspector(
         java.util.Arrays.asList(fields.map(f => f.name) : _*),
         java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*))
+    case _: UserDefinedType[_] =>
+      val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType
+      toInspector(sqlType)
   }
 
   /**
@@ -849,6 +852,8 @@
       }
     case Literal(_, dt: StructType) =>
       toInspector(dt)
+    case Literal(_, dt: UserDefinedType[_]) =>
+      toInspector(dt.sqlType)
     // We will enumerate all of the possible constant expressions, throw exception if we missed
     case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
     // ideally, we don't test the foldable here(but in optimizer), however, some of the
```

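Both hunks route a UDT to the inspector of its underlying storage type: the first in `toInspector(dataType: DataType)`, the second for constant (`Literal`) UDT values. As a hedged illustration of what `sqlType` means here, a minimal UDT modeled on Spark's test-only `ExamplePointUDT`; the `Point2D` names are hypothetical, and since `UserDefinedType` is `private[spark]`, such a class has to live under an `org.apache.spark` package:

```scala
package org.apache.spark.example

import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class: not a Catalyst type by itself.
class Point2D(val x: Double, val y: Double)

class Point2DUDT extends UserDefinedType[Point2D] {
  // toInspector(dt) now delegates to toInspector(dt.sqlType), so Hive sees
  // values of this UDT as their storage type: array<double>.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Store a point as the two-element array [x, y].
  override def serialize(p: Point2D): GenericArrayData =
    new GenericArrayData(Array(p.x, p.y))

  override def deserialize(datum: Any): Point2D = datum match {
    case a: ArrayData => new Point2D(a.getDouble(0), a.getDouble(1))
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}
```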
**`sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala`**

```diff
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.io.LongWritable
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{Row, TestUserClassUDT}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
@@ -214,6 +214,12 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
     })
   }
 
+  test("wrap / unwrap UDT Type") {
+    val dt = new TestUserClassUDT
+    checkValue(1, unwrap(wrap(1, toInspector(dt), dt), toInspector(dt)))
+    checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
+  }
+
   test("wrap / unwrap Struct Type") {
     val dt = StructType(dataTypes.zipWithIndex.map {
       case (t, idx) => StructField(s"c_$idx", t)
```

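The new test round-trips a raw value through the UDT's inspectors: `wrap` converts the Catalyst value into the Hive representation of `TestUserClassUDT`'s underlying `sqlType` (an integer type, judging by the literal `1` used here), and `unwrap` converts it back, exercising the new `toInspector` branches above.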
**`sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala`** (new file)

```diff
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+
+import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
+import org.apache.spark.sql.types.StructType
+
+class HiveUserDefinedTypeSuite extends QueryTest with TestHiveSingleton {
+  private val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName
+
+  test("Support UDT in Hive UDF") {
+    val functionName = "get_point_x"
+    try {
+      val schema = new StructType().add("point", new ExamplePointUDT)
+      val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
+      val input = inputGenerator.apply().asInstanceOf[Row]
+      val df = spark.createDataFrame(Array(input).toList.asJava, schema)
+      df.createOrReplaceTempView("src")
+
+      spark.sql(s"CREATE FUNCTION $functionName AS '$functionClass'")
+      checkAnswer(
+        spark.sql(s"SELECT $functionName(point) FROM src"),
+        Row(input.getAs[ExamplePoint](0).x))
+    } finally {
+      // If the test failed part way, we don't want to mask the failure by failing to remove
+      // temp tables that never got created.
+      spark.sql(s"DROP FUNCTION IF EXISTS $functionName")
+      assert(
+        !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+        s"Function $functionName should have been dropped. But, it still exists.")
+    }
+  }
+}
+
+class TestUDF extends GenericUDF {
+  private var data: StandardListObjectInspector = _
+
+  override def getDisplayString(children: Array[String]): String = "get_point_x"
+
+  override def initialize(arguments: Array[ObjectInspector]): ObjectInspector = {
+    data = arguments(0).asInstanceOf[StandardListObjectInspector]
+    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+  }
+
+  override def evaluate(arguments: Array[GenericUDF.DeferredObject]): AnyRef = {
+    val point = data.getList(arguments(0).get())
+    new java.lang.Double(point.get(0).asInstanceOf[Double])
+  }
+}
```
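A note on why `TestUDF` casts to `StandardListObjectInspector`: `ExamplePointUDT.sqlType` is `ArrayType(DoubleType, false)` (as defined in Spark's test sources), and a point is serialized as the array `[x, y]`. With this patch, `toInspector` hands Hive the list inspector for that array, so the UDF receives each point as a two-element list of doubles and returns element 0, the x coordinate, which is exactly what the `checkAnswer` above asserts.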