diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 349faae40b..26dc372d7c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -61,7 +61,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) @@ -144,7 +144,7 @@ private[hive] case class HiveGenericUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 48adc833f4..4098bb597b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.io.Writable +import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(count4 == 1) sql("DROP TABLE parquet_tmp") } + + test("Hive Stateful UDF") { + withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) { + sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'") + val testData = spark.range(10).repartition(1) + + // Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1. + checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10)) + + // Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1, + // and the data is evenly distributed into 2 partitions. + checkAnswer(testData.repartition(2) + .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5)) + + // Expected Max(s) is 1, as stateless UDF is deterministic and foldable and replaced + // by constant 1 by ConstantFolding optimizer. + checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { @@ -551,3 +573,22 @@ class PairUDF extends GenericUDF { override def getDisplayString(p1: Array[String]): String = "" } + +@UDFType(stateful = true) +class StatefulUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +} + +class StatelessUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +}