[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF
Currently PySpark can only call built-in Java UDFs; it cannot call custom Java UDFs. This change allows that, with two benefits: * Leverage the rich ecosystem of third-party Java libraries. * Improve performance: with a Python UDF, Python daemons are started on the workers, which hurts performance. Author: Jeff Zhang <zjffdu@apache.org> Closes #9766 from zjffdu/SPARK-11775.
This commit is contained in:
parent
5aeb7384c7
commit
f00df40cfe
|
@ -28,7 +28,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession
|
|||
from pyspark.sql.dataframe import DataFrame
|
||||
from pyspark.sql.readwriter import DataFrameReader
|
||||
from pyspark.sql.streaming import DataStreamReader
|
||||
from pyspark.sql.types import Row, StringType
|
||||
from pyspark.sql.types import IntegerType, Row, StringType
|
||||
from pyspark.sql.utils import install_exception_handler
|
||||
|
||||
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
|
||||
|
@ -202,6 +202,32 @@ class SQLContext(object):
|
|||
"""
|
||||
self.sparkSession.catalog.registerFunction(name, f, returnType)
|
||||
|
||||
@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None):
    """Register a Java UDF so that it can be called from SQL statements.

    The UDF is identified by the fully qualified name of its Java class. A return
    type may be passed explicitly; when it is omitted, it is inferred via reflection.

    :param name: name of the UDF
    :param javaClassName: fully qualified name of java class
    :param returnType: a :class:`pyspark.sql.types.DataType` object

    >>> sqlContext.registerJavaFunction("javaStringLength",
    ...     "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
    [Row(UDF(test)=4)]
    >>> sqlContext.registerJavaFunction("javaStringLength2",
    ...     "test.org.apache.spark.sql.JavaStringLength")
    >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
    [Row(UDF(test)=4)]

    """
    jsession = self.sparkSession._jsparkSession
    # The JVM side accepts a null data type, in which case it infers the return type.
    jdt = None if returnType is None else jsession.parseDataType(returnType.json())
    jsession.udf().registerJava(name, javaClassName, jdt)
|
||||
|
||||
# TODO(andrew): delete this once we refactor things to take in SparkSession
|
||||
def _inferSchema(self, rdd, samplingRatio=None):
|
||||
"""
|
||||
|
|
|
@ -59,7 +59,7 @@ object JavaTypeInference {
|
|||
* @param typeToken Java type
|
||||
* @return (SQL data type, nullable)
|
||||
*/
|
||||
private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
|
||||
private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
|
||||
typeToken.getRawType match {
|
||||
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
|
||||
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
|
||||
|
|
|
@ -17,19 +17,25 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.io.IOException
|
||||
import java.lang.reflect.{ParameterizedType, Type}
|
||||
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.Try
|
||||
|
||||
import com.google.common.reflect.TypeToken
|
||||
|
||||
import org.apache.spark.annotation.InterfaceStability
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.api.java._
|
||||
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
|
||||
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
|
||||
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
|
||||
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.types.{DataType, DataTypes}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
|
||||
|
@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|
|||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
 * Register a Java UDF class using reflection, for use from pyspark.
 *
 * The class is loaded by name, must implement exactly one of the
 * `org.apache.spark.sql.api.java.UDF*` interfaces, and is instantiated through its
 * no-argument constructor. Failures to load or instantiate the class, and interfaces with
 * an unsupported number of type arguments, are logged as errors and nothing is registered.
 *
 * @param name   udf name
 * @param className   fully qualified class name of udf
 * @param returnDataType  return type of udf. If it is null, spark would try to infer
 *                        via reflection.
 * @throws IOException if the class implements none, or more than one, of the UDF interfaces
 */
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {

  try {
    val clazz = Utils.classForName(className)
    val udfInterfaces = clazz.getGenericInterfaces
      .collect { case p: ParameterizedType => p }
      .filter { p =>
        p.getRawType match {
          // getCanonicalName returns null for local/anonymous classes, so guard against it.
          case raw: Class[_] =>
            Option(raw.getCanonicalName)
              .exists(_.startsWith("org.apache.spark.sql.api.java.UDF"))
          case _ => false
        }
      }
    if (udfInterfaces.length == 0) {
      throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
    } else if (udfInterfaces.length > 1) {
      throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
    } else {
      try {
        val udf = clazz.newInstance()
        val typeArgs = udfInterfaces(0).getActualTypeArguments
        // The last type argument of a UDFn interface is its result type; it is only
        // used for inference when the caller did not supply an explicit return type.
        val returnType = Option(returnDataType)
          .getOrElse(JavaTypeInference.inferDataType(TypeToken.of(typeArgs.last))._1)

        // UDFn has n + 1 type arguments (n inputs plus the result).
        typeArgs.length match {
          case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
          case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
          case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
          case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
          case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
          case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
          case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
          case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
          case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
          case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
          case n => logError(s"UDF class with ${n} type arguments is not supported ")
        }
      } catch {
        // Class.newInstance reports a non-public class/constructor as IllegalAccessException,
        // which is exactly the case the message below describes. IllegalArgumentException is
        // kept so that everything previously handled here is still handled.
        case e @ (_: InstantiationException | _: IllegalAccessException |
                  _: IllegalArgumentException) =>
          logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor", e)
      }
    }
  } catch {
    case e: ClassNotFoundException =>
      logError(s"Can not load class ${className}, please make sure it is on the classpath", e)
  }

}
|
||||
|
||||
/**
|
||||
* Register a user-defined function with 1 arguments.
|
||||
* @since 1.3.0
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import org.apache.spark.sql.api.java.UDF1;
|
||||
|
||||
/**
|
||||
* It is used for register Java UDF from PySpark
|
||||
*/
|
||||
public class JavaStringLength implements UDF1<String, Integer> {
|
||||
@Override
|
||||
public Integer call(String str) throws Exception {
|
||||
return new Integer(str.length());
|
||||
}
|
||||
}
|
|
@ -87,4 +87,25 @@ public class JavaUDFSuite implements Serializable {
|
|||
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
|
||||
Assert.assertEquals(9, result.getInt(0));
|
||||
}
|
||||
|
||||
public static class StringLengthTest implements UDF2<String, String, Integer> {
|
||||
@Override
|
||||
public Integer call(String str1, String str2) throws Exception {
|
||||
return new Integer(str1.length() + str2.length());
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void udf3Test() {
|
||||
spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
|
||||
DataTypes.IntegerType);
|
||||
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
|
||||
Assert.assertEquals(9, result.getInt(0));
|
||||
|
||||
// returnType is not provided
|
||||
spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
|
||||
result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
|
||||
Assert.assertEquals(9, result.getInt(0));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue