[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF

Currently PySpark can only call the built-in Java UDFs; it cannot call custom Java UDFs. It would be better to allow that, for two reasons:
* Leverage the power of rich third-party Java libraries.
* Improve performance: a Python UDF requires Python daemons to be started on the workers, which hurts performance, whereas a Java UDF runs entirely inside the JVM (see the sketch below).
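
A minimal PySpark sketch of what this change enables, using the test UDF added in this patch. It assumes the jar containing the Java class is already on the application classpath (for user code, e.g. added via spark-submit --jars) and that sqlContext is an existing SQLContext:

    # Python UDF: rows are shipped to Python worker daemons on the executors
    from pyspark.sql.types import IntegerType
    sqlContext.registerFunction("pyStringLength", lambda s: len(s), IntegerType())

    # Java UDF (this patch): registered by class name, executed entirely in the JVM
    sqlContext.registerJavaFunction(
        "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    sqlContext.sql("SELECT javaStringLength('test')").collect()  # [Row(UDF(test)=4)]

When returnType is omitted, the return type is inferred from the UDF class via reflection.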

Author: Jeff Zhang <zjffdu@apache.org>

Closes #9766 from zjffdu/SPARK-11775.
Jeff Zhang 2016-10-14 15:50:35 -07:00 committed by Michael Armbrust
parent 5aeb7384c7
commit f00df40cfe
5 changed files with 152 additions and 4 deletions


@@ -28,7 +28,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import Row, StringType
from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.utils import install_exception_handler
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
@@ -202,6 +202,32 @@ class SQLContext(object):
"""
self.sparkSession.catalog.registerFunction(name, f, returnType)
@ignore_unicode_prefix
@since(2.1)
def registerJavaFunction(self, name, javaClassName, returnType=None):
"""Register a java UDF so it can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not specified we would infer it via reflection.
:param name: name of the UDF
:param javaClassName: fully qualified name of java class
:param returnType: a :class:`pyspark.sql.types.DataType` object
>>> sqlContext.registerJavaFunction("javaStringLength",
... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
>>> sqlContext.sql("SELECT javaStringLength('test')").collect()
[Row(UDF(test)=4)]
>>> sqlContext.registerJavaFunction("javaStringLength2",
... "test.org.apache.spark.sql.JavaStringLength")
>>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
[Row(UDF(test)=4)]
"""
jdt = None
if returnType is not None:
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""


@@ -59,7 +59,7 @@ object JavaTypeInference {
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)


@@ -17,19 +17,25 @@
package org.apache.spark.sql
import java.io.IOException
import java.lang.reflect.{ParameterizedType, Type}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import com.google.common.reflect.TypeToken
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, DataTypes}
import org.apache.spark.util.Utils
/**
* Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
@@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
/**
* Register a Java UDF class using reflection, for use from PySpark.
*
* @param name UDF name
* @param className fully qualified class name of the UDF
* @param returnDataType return type of the UDF. If it is null, Spark would try to infer it
* via reflection.
*/
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
try {
val clazz = Utils.classForName(className)
val udfInterfaces = clazz.getGenericInterfaces
.filter(_.isInstanceOf[ParameterizedType])
.map(_.asInstanceOf[ParameterizedType])
.filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
if (udfInterfaces.length == 0) {
throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
} else if (udfInterfaces.length > 1) {
throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
} else {
try {
val udf = clazz.newInstance()
val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
var returnType = returnDataType
if (returnType == null) {
returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1
}
udfInterfaces(0).getActualTypeArguments.length match {
case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case n => logError(s"UDF class with ${n} type arguments is not supported")
}
} catch {
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
}
}
} catch {
case e: ClassNotFoundException => logError(s"Cannot load class ${className}, please make sure it is on the classpath")
}
}
/**
* Register a user-defined function with 1 argument.
* @since 1.3.0


@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package test.org.apache.spark.sql;
import org.apache.spark.sql.api.java.UDF1;
/**
* It is used for registering a Java UDF from PySpark.
*/
public class JavaStringLength implements UDF1<String, Integer> {
@Override
public Integer call(String str) throws Exception {
return new Integer(str.length());
}
}


@@ -87,4 +87,25 @@ public class JavaUDFSuite implements Serializable {
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
public static class StringLengthTest implements UDF2<String, String, Integer> {
@Override
public Integer call(String str1, String str2) throws Exception {
return new Integer(str1.length() + str2.length());
}
}
@SuppressWarnings("unchecked")
@Test
public void udf3Test() {
spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
DataTypes.IntegerType);
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
// returnType is not provided, so it is inferred from the UDF class via reflection
spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
result = spark.sql("SELECT stringLengthTest2('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
}