[SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs
## What changes were proposed in this pull request? Support registering Java UDAFs in PySpark so that users can call Java UDAFs from PySpark. In addition, a corresponding API is added to `UDFRegistration`. ## How was this patch tested? Unit tests are added. Author: Jeff Zhang <zjffdu@apache.org> Closes #17222 from zjffdu/SPARK-19439.
This commit is contained in:
parent
960298ee66
commit
742da08685
|
@ -232,6 +232,23 @@ class SQLContext(object):
|
|||
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
|
||||
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
|
||||
|
||||
@ignore_unicode_prefix
@since(2.3)
def registerJavaUDAF(self, name, javaClassName):
    """Make a Java user-defined aggregate function (UDAF) callable from SQL.

    The registration is forwarded to the ``udf()`` registry of the
    session's ``_jsparkSession`` object.

    :param name: name under which the UDAF is registered
    :param javaClassName: fully qualified name of the Java class implementing the UDAF

    >>> sqlContext.registerJavaUDAF("javaUDAF",
    ...   "test.org.apache.spark.sql.MyDoubleAvg")
    >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
    >>> df.registerTempTable("df")
    >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
    [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
    """
    java_udf_registry = self.sparkSession._jsparkSession.udf()
    java_udf_registry.registerJavaUDAF(name, javaClassName)
|
||||
|
||||
# TODO(andrew): delete this once we refactor things to take in SparkSession
|
||||
def _inferSchema(self, rdd, samplingRatio=None):
|
||||
"""
|
||||
|
@ -551,6 +568,12 @@ class UDFRegistration(object):
|
|||
def register(self, name, f, returnType=StringType()):
    # Pure delegation: the wrapped SQLContext performs the registration and
    # its result is handed back to the caller unchanged.
    context = self.sqlContext
    return context.registerFunction(name, f, returnType)
|
||||
|
||||
def registerJavaFunction(self, name, javaClassName, returnType=None):
    """Forward a Java UDF registration to the wrapped ``sqlContext``.

    :param name: name under which the function is registered
    :param javaClassName: fully qualified name of the Java class
    :param returnType: passed through unchanged; defaults to ``None``
    """
    forward = self.sqlContext.registerJavaFunction
    forward(name, javaClassName, returnType)
|
||||
|
||||
def registerJavaUDAF(self, name, javaClassName):
    """Forward a Java UDAF registration to the wrapped ``sqlContext``.

    :param name: name under which the UDAF is registered
    :param javaClassName: fully qualified name of the Java class
    """
    forward = self.sqlContext.registerJavaUDAF
    forward(name, javaClassName)
|
||||
|
||||
register.__doc__ = SQLContext.registerFunction.__doc__
|
||||
|
||||
|
||||
|
|
|
@ -481,6 +481,16 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
df.select(add_three("id").alias("plus_three")).collect()
|
||||
)
|
||||
|
||||
def test_non_existed_udf(self):
    # Registering a Java UDF whose class cannot be loaded must raise an
    # AnalysisException that names the missing class.
    udf_registry = self.spark.udf
    with self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf"):
        udf_registry.registerJavaFunction("udf1", "non_existed_udf")
|
||||
|
||||
def test_non_existed_udaf(self):
    # Registering a Java UDAF whose class cannot be loaded must raise an
    # AnalysisException that names the missing class.
    udf_registry = self.spark.udf
    with self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf"):
        udf_registry.registerJavaUDAF("udaf1", "non_existed_udaf")
|
||||
|
||||
def test_multiLine_json(self):
|
||||
people1 = self.spark.read.json("python/test_support/sql/people.json")
|
||||
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.io.IOException
|
||||
import java.lang.reflect.{ParameterizedType, Type}
|
||||
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|
|||
.map(_.asInstanceOf[ParameterizedType])
|
||||
.filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
|
||||
if (udfInterfaces.length == 0) {
|
||||
throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
|
||||
throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface")
|
||||
} else if (udfInterfaces.length > 1) {
|
||||
throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
|
||||
throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
|
||||
} else {
|
||||
try {
|
||||
val udf = clazz.newInstance()
|
||||
|
@ -491,19 +490,41 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|
|||
case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
|
||||
case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
|
||||
case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
|
||||
case n => logError(s"UDF class with ${n} type arguments is not supported ")
|
||||
case n =>
|
||||
throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.")
|
||||
}
|
||||
} catch {
|
||||
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
|
||||
logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
|
||||
throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
|
||||
case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
 * Register a Java UDAF class using reflection, for use from pyspark.
 *
 * The class is loaded by name, checked to be a [[UserDefinedAggregateFunction]],
 * instantiated through its no-argument constructor and then registered under `name`.
 *
 * @param name       UDAF name
 * @param className  fully qualified class name of UDAF
 * @throws AnalysisException if the class cannot be loaded, does not extend
 *                           [[UserDefinedAggregateFunction]], or cannot be instantiated
 */
private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
  try {
    val clazz = Utils.classForName(className)
    if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
      throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction")
    }
    val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]
    register(name, udaf)
  } catch {
    case _: ClassNotFoundException =>
      throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
    // Class.newInstance() signals a non-public class or constructor with
    // IllegalAccessException; handle it here so that case gets the same
    // actionable message instead of escaping as a raw reflection error.
    case _: InstantiationException | _: IllegalAccessException | _: IllegalArgumentException =>
      throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
  }
}
|
||||
|
||||
/**
|
||||
* Register a user-defined function with 1 arguments.
|
||||
* @since 1.3.0
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.SparkSession;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
|
||||
public class JavaUDAFSuite {
|
||||
|
||||
private transient SparkSession spark;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
spark = SparkSession.builder()
|
||||
.master("local[*]")
|
||||
.appName("testing")
|
||||
.getOrCreate();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
spark.stop();
|
||||
spark = null;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Test
|
||||
public void udf1Test() {
|
||||
spark.range(1, 10).toDF("value").registerTempTable("df");
|
||||
spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName());
|
||||
Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head();
|
||||
Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6);
|
||||
}
|
||||
|
||||
}
|
|
@ -15,7 +15,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive.aggregate;
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
|
@ -15,18 +15,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive.aggregate;
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
|
||||
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
import org.apache.spark.sql.types.DataType;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
||||
/**
|
||||
* An example {@link UserDefinedAggregateFunction} to calculate the sum of a
|
|
@ -57,6 +57,13 @@
|
|||
<artifactId>spark-sql_${scala.binary.version}</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-sql_${scala.binary.version}</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-tags_${scala.binary.version}</artifactId>
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.Window;
|
|||
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
|
||||
import static org.apache.spark.sql.functions.*;
|
||||
import org.apache.spark.sql.hive.test.TestHive$;
|
||||
import org.apache.spark.sql.hive.aggregate.MyDoubleSum;
|
||||
import test.org.apache.spark.sql.MyDoubleSum;
|
||||
|
||||
public class JavaDataFrameSuite {
|
||||
private transient SQLContext hc;
|
||||
|
|
|
@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution
|
|||
import scala.collection.JavaConverters._
|
||||
import scala.util.Random
|
||||
|
||||
import test.org.apache.spark.sql.MyDoubleAvg
|
||||
import test.org.apache.spark.sql.MyDoubleSum
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
|
||||
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
|
||||
import org.apache.spark.sql.hive.test.TestHiveSingleton
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
|
||||
|
||||
def inputSchema: StructType = schema
|
||||
|
|
Loading…
Reference in a new issue