[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic

## What changes were proposed in this pull request?

Make stateful udf as nondeterministic

## How was this patch tested?
Add new test cases with both Stateful and Stateless UDF.
Without the patch, the test cases will throw exception:

1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
        at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
        at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
        ...

Author: Zhan Zhang <zhanzhang@fb.com>

Closes #16068 from zhzhan/state.
This commit is contained in:
Zhan Zhang 2016-12-09 16:35:06 +08:00 committed by Wenchen Fan
parent c074c96dc5
commit 67587d961d
2 changed files with 45 additions and 4 deletions

View file

@ -61,7 +61,7 @@ private[hive] case class HiveSimpleUDF(
@transient @transient
private lazy val isUDFDeterministic = { private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic() udfType != null && udfType.deterministic() && !udfType.stateful()
} }
override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
@ -144,7 +144,7 @@ private[hive] case class HiveGenericUDF(
@transient @transient
private lazy val isUDFDeterministic = { private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic() udfType != null && udfType.deterministic() && !udfType.stateful()
} }
@transient @transient

View file

@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
import java.util.{ArrayList, Arrays, Properties} import java.util.{ArrayList, Arrays, Properties}
import org.apache.hadoop.conf.Configuration import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.udf.UDAFPercentile import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.Writable import org.apache.hadoop.io.{LongWritable, Writable}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
assert(count4 == 1) assert(count4 == 1)
sql("DROP TABLE parquet_tmp") sql("DROP TABLE parquet_tmp")
} }
test("Hive Stateful UDF") {
withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'")
sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'")
val testData = spark.range(10).repartition(1)
// Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1.
checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10))
// Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1,
// and the data is evenly distributed into 2 partitions.
checkAnswer(testData.repartition(2)
.selectExpr("statefulUDF() as s").agg(max($"s")), Row(5))
// Expected Max(s) is 1, as stateless UDF is deterministic and foldable and replaced
// by constant 1 by ConstantFolding optimizer.
checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1))
}
}
} }
class TestPair(x: Int, y: Int) extends Writable with Serializable { class TestPair(x: Int, y: Int) extends Writable with Serializable {
@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
override def getDisplayString(p1: Array[String]): String = "" override def getDisplayString(p1: Array[String]): String = ""
} }
@UDFType(stateful = true)
class StatefulUDF extends UDF {
private val result = new LongWritable(0)
def evaluate(): LongWritable = {
result.set(result.get() + 1)
result
}
}
class StatelessUDF extends UDF {
private val result = new LongWritable(0)
def evaluate(): LongWritable = {
result.set(result.get() + 1)
result
}
}