[SPARK-19163][PYTHON][SQL] Delay _judf initialization to the __call__
## What changes were proposed in this pull request? Defer `UserDefinedFunction._judf` initialization to the first call. This prevents unintended `SparkSession` initialization. It allows users to define and import UDFs without creating a context or session as a side effect. [SPARK-19163](https://issues.apache.org/jira/browse/SPARK-19163) ## How was this patch tested? Unit tests. Author: zero323 <zero323@users.noreply.github.com> Closes #16536 from zero323/SPARK-19163.
This commit is contained in:
parent
081b7addaf
commit
9063835803
|
@ -1826,25 +1826,38 @@ class UserDefinedFunction(object):
|
|||
def __init__(self, func, returnType, name=None):
    """Create a UDF wrapper without touching the JVM.

    :param func: the Python callable to wrap.
    :param returnType: the :class:`DataType` the UDF returns.
    :param name: optional explicit name; defaults to the callable's
        ``__name__`` (or its class name for callables without one).
    """
    self.func = func
    self.returnType = returnType
    # Stores the UserDefinedPythonFunction jobj, once initialized.
    # Kept as None here so that merely constructing (or importing a module
    # that constructs) a UDF does not force SparkSession/SparkContext
    # creation as a side effect (SPARK-19163).
    self._judf_placeholder = None
    self._name = name or (
        func.__name__ if hasattr(func, '__name__')
        else func.__class__.__name__)
|
||||
@property
def _judf(self):
    """Lazily-initialized JVM UserDefinedPythonFunction handle.

    It is possible that concurrent access, to a newly created UDF,
    will initialize multiple UserDefinedPythonFunctions.
    This is unlikely, doesn't affect correctness
    (every instance is equivalent), and should have a minimal
    performance impact, so no lock is taken here.
    """
    if self._judf_placeholder is None:
        self._judf_placeholder = self._create_judf()
    return self._judf_placeholder
||||
def _create_judf(self):
    """Build the JVM-side UserDefinedPythonFunction for this UDF.

    This is the only place that forces a SparkSession (and hence a
    SparkContext) into existence, which is why it is deferred until the
    UDF is first called rather than run from ``__init__``.
    """
    # Imported locally to avoid a circular import at module load time.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext

    wrapped_func = _wrap_function(sc, self.func, self.returnType)
    jdt = spark._jsparkSession.parseDataType(self.returnType.json())
    judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
        self._name, wrapped_func, jdt)
    return judf
||||
|
||||
def __call__(self, *cols):
    """Apply the UDF to the given columns and return the result Column.

    Accessing ``self._judf`` here triggers the lazy JVM initialization
    on first use; subsequent calls reuse the cached handle.
    """
    judf = self._judf
    sc = SparkContext._active_spark_context
    return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
||||
@since(1.3)
|
||||
|
|
|
@ -468,6 +468,27 @@ class SQLTests(ReusedPySparkTestCase):
|
|||
row2 = df2.select(sameText(df2['file'])).first()
|
||||
self.assertTrue(row2[0].find("people.json") != -1)
|
||||
|
||||
def test_udf_defers_judf_initalization(self):
    # This is kept separate from UDFInitializationTests
    # to avoid context initialization
    # when the udf is called.

    from pyspark.sql.functions import UserDefinedFunction

    f = UserDefinedFunction(lambda x: x, StringType())

    # Construction alone must not create the JVM-side function.
    self.assertIsNone(
        f._judf_placeholder,
        "judf should not be initialized before the first call."
    )

    self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")

    # The first call populates (and caches) the JVM-side function.
    self.assertIsNotNone(
        f._judf_placeholder,
        "judf should be initialized after UDF has been called."
    )
|
||||
|
||||
def test_basic_functions(self):
|
||||
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
|
||||
df = self.spark.read.json(rdd)
|
||||
|
@ -1947,6 +1968,29 @@ class SQLTests2(ReusedPySparkTestCase):
|
|||
df.collect()
|
||||
|
||||
|
||||
class UDFInitializationTests(unittest.TestCase):
    """Checks that creating a UDF does not initialize a context/session.

    Runs without the shared ReusedPySparkTestCase fixtures on purpose:
    these tests assert that no SparkContext/SparkSession exists.
    """

    def tearDown(self):
        # Stop anything a test accidentally created so later suites start clean.
        if SparkSession._instantiatedSession is not None:
            SparkSession._instantiatedSession.stop()

        if SparkContext._active_spark_context is not None:
            # BUG FIX: original read `_active_spark_contex` (missing trailing
            # 't'), which raised AttributeError instead of stopping the context.
            SparkContext._active_spark_context.stop()

    def test_udf_init_shouldnt_initalize_context(self):
        from pyspark.sql.functions import UserDefinedFunction

        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            SparkContext._active_spark_context,
            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
        )
        self.assertIsNone(
            SparkSession._instantiatedSession,
            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
        )
|
||||
|
||||
|
||||
class HiveContextSQLTests(ReusedPySparkTestCase):
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in a new issue