[SPARK-28158][SQL] Hive UDFs supports UDT type
## What changes were proposed in this pull request? After this PR, we can create and register Hive UDFs to accept UDT type, like `VectorUDT` and `MatrixUDT`. These UDTs are widely used in Spark machine learning. ## How was this patch tested? add new ut Closes #24961 from uncleGen/SPARK-28158. Authored-by: uncleGen <hustyugm@gmail.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
a8d5134981
commit
0182817ea3
|
@ -787,6 +787,9 @@ private[hive] trait HiveInspectors {
|
|||
ObjectInspectorFactory.getStandardStructObjectInspector(
|
||||
java.util.Arrays.asList(fields.map(f => f.name) : _*),
|
||||
java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*))
|
||||
case _: UserDefinedType[_] =>
|
||||
val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType
|
||||
toInspector(sqlType)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -849,6 +852,8 @@ private[hive] trait HiveInspectors {
|
|||
}
|
||||
case Literal(_, dt: StructType) =>
|
||||
toInspector(dt)
|
||||
case Literal(_, dt: UserDefinedType[_]) =>
|
||||
toInspector(dt.sqlType)
|
||||
// We will enumerate all of the possible constant expressions, throw exception if we missed
|
||||
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
|
||||
// ideally, we don't test the foldable here(but in optimizer), however, some of the
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
|
|||
import org.apache.hadoop.io.LongWritable
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.{Row, TestUserClassUDT}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
|
||||
|
@ -214,6 +214,12 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
|
|||
})
|
||||
}
|
||||
|
||||
test("wrap / unwrap UDT Type") {
|
||||
val dt = new TestUserClassUDT
|
||||
checkValue(1, unwrap(wrap(1, toInspector(dt), dt), toInspector(dt)))
|
||||
checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
|
||||
}
|
||||
|
||||
test("wrap / unwrap Struct Type") {
|
||||
val dt = StructType(dataTypes.zipWithIndex.map {
|
||||
case (t, idx) => StructField(s"c_$idx", t)
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector}
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
|
||||
|
||||
import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row}
|
||||
import org.apache.spark.sql.catalyst.FunctionIdentifier
|
||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
class HiveUserDefinedTypeSuite extends QueryTest with TestHiveSingleton {
|
||||
private val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName
|
||||
|
||||
test("Support UDT in Hive UDF") {
|
||||
val functionName = "get_point_x"
|
||||
try {
|
||||
val schema = new StructType().add("point", new ExamplePointUDT)
|
||||
val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
|
||||
val input = inputGenerator.apply().asInstanceOf[Row]
|
||||
val df = spark.createDataFrame(Array(input).toList.asJava, schema)
|
||||
df.createOrReplaceTempView("src")
|
||||
spark.sql(s"CREATE FUNCTION $functionName AS '$functionClass'")
|
||||
|
||||
checkAnswer(
|
||||
spark.sql(s"SELECT $functionName(point) FROM src"),
|
||||
Row(input.getAs[ExamplePoint](0).x))
|
||||
} finally {
|
||||
// If the test failed part way, we don't want to mask the failure by failing to remove
|
||||
// temp tables that never got created.
|
||||
spark.sql(s"DROP FUNCTION IF EXISTS $functionName")
|
||||
assert(
|
||||
!spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
|
||||
s"Function $functionName should have been dropped. But, it still exists.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestUDF extends GenericUDF {
|
||||
private var data: StandardListObjectInspector = _
|
||||
|
||||
override def getDisplayString(children: Array[String]): String = "get_point_x"
|
||||
|
||||
override def initialize(arguments: Array[ObjectInspector]): ObjectInspector = {
|
||||
data = arguments(0).asInstanceOf[StandardListObjectInspector]
|
||||
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
|
||||
}
|
||||
|
||||
override def evaluate(arguments: Array[GenericUDF.DeferredObject]): AnyRef = {
|
||||
val point = data.getList(arguments(0).get())
|
||||
new java.lang.Double(point.get(0).asInstanceOf[Double])
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue