[SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF
## What changes were proposed in this pull request?

In SPARK-20586 the flag `deterministic` was added to Scala UDF, but it is not available for Python UDF. This flag is useful for cases when the UDF's code can return different results for the same input. Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than it is present in the query. This can lead to unexpected behavior.

This PR adds the deterministic flag, via the `asNondeterministic` method, to let the user mark the function as non-deterministic and therefore avoid the optimizations which might lead to strange behaviors.

## How was this patch tested?

Manual tests:

```
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> df_br = spark.createDataFrame([{'name': 'hello'}])
>>> import random
>>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic()
>>> df_br = df_br.withColumn('RAND', udf_random_col())
>>> random.seed(1234)
>>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
>>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show()
+-----+----+-------------+
| name|RAND|RAND_PLUS_TEN|
+-----+----+-------------+
|hello|   3|           13|
+-----+----+-------------+
```

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Marco Gaido <mgaido@hortonworks.com>

Closes #19929 from mgaido91/SPARK-22629.
This commit is contained in:
parent
eb386be1ed
commit
ff48b1b338
|
@ -39,6 +39,13 @@ private[spark] object PythonEvalType {
|
|||
|
||||
val SQL_PANDAS_SCALAR_UDF = 200
|
||||
val SQL_PANDAS_GROUP_MAP_UDF = 201
|
||||
|
||||
def toString(pythonEvalType: Int): String = pythonEvalType match {
|
||||
case NON_UDF => "NON_UDF"
|
||||
case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
|
||||
case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
|
||||
case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -2093,9 +2093,14 @@ class PandasUDFType(object):
|
|||
def udf(f=None, returnType=StringType()):
|
||||
"""Creates a user defined function (UDF).
|
||||
|
||||
.. note:: The user-defined functions must be deterministic. Due to optimization,
|
||||
duplicate invocations may be eliminated or the function may even be invoked more times than
|
||||
it is present in the query.
|
||||
.. note:: The user-defined functions are considered deterministic by default. Due to
|
||||
optimization, duplicate invocations may be eliminated or the function may even be invoked
|
||||
more times than it is present in the query. If your function is not deterministic, call
|
||||
`asNondeterministic` on the user defined function. E.g.:
|
||||
|
||||
>>> from pyspark.sql.types import IntegerType
|
||||
>>> import random
|
||||
>>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
|
||||
|
||||
.. note:: The user-defined functions do not support conditional expressions or short circuiting
|
||||
in boolean expressions and it ends up with being executed all internally. If the functions
|
||||
|
|
|
@ -435,6 +435,15 @@ class SQLTests(ReusedSQLTestCase):
|
|||
self.assertEqual(list(range(3)), l1)
|
||||
self.assertEqual(1, l2)
|
||||
|
||||
def test_nondeterministic_udf(self):
|
||||
from pyspark.sql.functions import udf
|
||||
import random
|
||||
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
|
||||
df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
|
||||
udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
|
||||
[row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
|
||||
self.assertEqual(row[0] + 10, row[1])
|
||||
|
||||
def test_broadcast_in_udf(self):
|
||||
bar = {"a": "aa", "b": "bb", "c": "abc"}
|
||||
foo = self.sc.broadcast(bar)
|
||||
|
|
|
@ -92,6 +92,7 @@ class UserDefinedFunction(object):
|
|||
func.__name__ if hasattr(func, '__name__')
|
||||
else func.__class__.__name__)
|
||||
self.evalType = evalType
|
||||
self._deterministic = True
|
||||
|
||||
@property
|
||||
def returnType(self):
|
||||
|
@ -129,7 +130,7 @@ class UserDefinedFunction(object):
|
|||
wrapped_func = _wrap_function(sc, self.func, self.returnType)
|
||||
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
|
||||
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
|
||||
self._name, wrapped_func, jdt, self.evalType)
|
||||
self._name, wrapped_func, jdt, self.evalType, self._deterministic)
|
||||
return judf
|
||||
|
||||
def __call__(self, *cols):
|
||||
|
@ -161,5 +162,15 @@ class UserDefinedFunction(object):
|
|||
wrapper.func = self.func
|
||||
wrapper.returnType = self.returnType
|
||||
wrapper.evalType = self.evalType
|
||||
wrapper.asNondeterministic = self.asNondeterministic
|
||||
|
||||
return wrapper
|
||||
|
||||
def asNondeterministic(self):
|
||||
"""
|
||||
Updates UserDefinedFunction to nondeterministic.
|
||||
|
||||
.. versionadded:: 2.3
|
||||
"""
|
||||
self._deterministic = False
|
||||
return self
|
||||
|
|
|
@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag
|
|||
import scala.util.Try
|
||||
|
||||
import org.apache.spark.annotation.InterfaceStability
|
||||
import org.apache.spark.api.python.PythonEvalType
|
||||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.api.java._
|
||||
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
|
||||
|
@ -41,8 +42,6 @@ import org.apache.spark.util.Utils
|
|||
* spark.udf
|
||||
* }}}
|
||||
*
|
||||
* @note The user-defined functions must be deterministic.
|
||||
*
|
||||
* @since 1.3.0
|
||||
*/
|
||||
@InterfaceStability.Stable
|
||||
|
@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|
|||
| pythonIncludes: ${udf.func.pythonIncludes}
|
||||
| pythonExec: ${udf.func.pythonExec}
|
||||
| dataType: ${udf.dataType}
|
||||
| pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
|
||||
| udfDeterministic: ${udf.udfDeterministic}
|
||||
""".stripMargin)
|
||||
|
||||
functionRegistry.createOrReplaceTempFunction(name, udf.builder)
|
||||
|
|
|
@ -29,9 +29,12 @@ case class PythonUDF(
|
|||
func: PythonFunction,
|
||||
dataType: DataType,
|
||||
children: Seq[Expression],
|
||||
evalType: Int)
|
||||
evalType: Int,
|
||||
udfDeterministic: Boolean)
|
||||
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
|
||||
|
||||
override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
|
||||
|
||||
override def toString: String = s"$name(${children.mkString(", ")})"
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
|
|
@ -29,10 +29,11 @@ case class UserDefinedPythonFunction(
|
|||
name: String,
|
||||
func: PythonFunction,
|
||||
dataType: DataType,
|
||||
pythonEvalType: Int) {
|
||||
pythonEvalType: Int,
|
||||
udfDeterministic: Boolean) {
|
||||
|
||||
def builder(e: Seq[Expression]): PythonUDF = {
|
||||
PythonUDF(name, func, dataType, e, pythonEvalType)
|
||||
PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
|
||||
}
|
||||
|
||||
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
|
||||
|
|
|
@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
|
|||
name = "dummyUDF",
|
||||
func = new DummyUDF,
|
||||
dataType = BooleanType,
|
||||
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
|
||||
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
|
||||
udfDeterministic = true)
|
||||
|
|
Loading…
Reference in a new issue