5aadbc929c
## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile <gatorsmile@gmail.com> Closes #20137 from gatorsmile/registerFunction.
184 lines
7.4 KiB
Python
184 lines
7.4 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
User-defined function related classes and functions
|
|
"""
|
|
import functools
|
|
|
|
from pyspark import SparkContext
|
|
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
|
|
from pyspark.sql.column import Column, _to_java_column, _to_seq
|
|
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
|
|
|
|
|
|
def _wrap_function(sc, func, returnType):
    """Serialize ``(func, returnType)`` and wrap the result in a JVM
    ``PythonFunction`` so it can be shipped to executors.

    :param sc: the active :class:`SparkContext`
    :param func: the Python function to serialize
    :param returnType: the UDF's return :class:`DataType`
    :return: a py4j handle to a ``PythonFunction`` object
    """
    pickled, broadcasts, env, includes = _prepare_for_python_RDD(sc, (func, returnType))
    return sc._jvm.PythonFunction(
        bytearray(pickled), env, includes, sc.pythonExec, sc.pythonVer,
        broadcasts, sc._javaAccumulator)
|
|
|
|
|
|
def _create_udf(f, returnType, evalType):
    """Create a :class:`UserDefinedFunction` for ``f`` and return its callable wrapper.

    :param f: the Python function (or callable) to wrap
    :param returnType: a :class:`DataType` or a DDL-formatted type string
    :param evalType: one of the ``PythonEvalType`` constants selecting the
        execution mode (batched UDF, pandas scalar UDF, pandas group-map UDF)
    :return: the wrapped function produced by ``UserDefinedFunction._wrapped``
    :raise ValueError: if a pandas UDF has an unsupported argument signature
    """

    if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
            evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
        import inspect
        from pyspark.sql.utils import require_minimum_pyarrow_version

        # pandas UDFs require pyarrow for the Arrow-based data exchange.
        require_minimum_pyarrow_version()
        # `inspect.getargspec` is deprecated since Python 3.0 and removed in
        # Python 3.11; prefer `getfullargspec` when available and fall back
        # for Python 2. Both expose the `.args` / `.varargs` fields used below.
        if hasattr(inspect, 'getfullargspec'):
            argspec = inspect.getfullargspec(f)
        else:
            argspec = inspect.getargspec(f)

        # A scalar pandas UDF must accept at least one argument because its
        # output length is derived from the input series length.
        if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \
                argspec.varargs is None:
            raise ValueError(
                "Invalid function: 0-arg pandas_udfs are not supported. "
                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
            )

        if evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF and len(argspec.args) != 1:
            raise ValueError(
                "Invalid function: pandas_udfs with function type GROUP_MAP "
                "must take a single arg that is a pandas DataFrame."
            )

    # Set the name of the UserDefinedFunction object to be the name of function f
    udf_obj = UserDefinedFunction(
        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
    return udf_obj._wrapped()
|
|
|
|
|
|
class UserDefinedFunction(object):
    """
    User defined function in Python

    .. versionadded:: 1.3
    """
    def __init__(self, func,
                 returnType=StringType(),
                 name=None,
                 evalType=PythonEvalType.SQL_BATCHED_UDF,
                 deterministic=True):
        """
        :param func: a Python function or any callable object
        :param returnType: a :class:`DataType` or a DDL-formatted type string;
            string parsing is deferred until :attr:`returnType` is accessed so
            that no SparkContext is required at construction time
        :param name: the UDF name; defaults to the function's ``__name__``
            (or its class name for callables without one)
        :param evalType: one of the ``PythonEvalType`` int constants
        :param deterministic: whether Catalyst may assume the UDF is
            deterministic (e.g. for expression deduplication)
        :raise TypeError: if ``func``, ``returnType`` or ``evalType`` has an
            invalid type
        """
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
                "{0}".format(type(func)))

        if not isinstance(returnType, (DataType, str)):
            raise TypeError(
                "Invalid returnType: returnType should be DataType or str "
                "but is {}".format(returnType))

        if not isinstance(evalType, int):
            raise TypeError(
                "Invalid evalType: evalType should be an int but is {}".format(evalType))

        self.func = func
        self._returnType = returnType
        # Lazily-computed caches: the parsed DataType and, once initialized,
        # the JVM UserDefinedPythonFunction jobj.
        self._returnType_placeholder = None
        self._judf_placeholder = None
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self):
        """The UDF's return type as a :class:`DataType`.

        This makes sure this is called after SparkContext is initialized:
        ``_parse_datatype_string`` accesses the JVM for parsing a DDL
        formatted string.

        :raise ValueError: if a GROUP_MAP pandas UDF's return type is not a
            :class:`StructType`
        """
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, DataType):
                self._returnType_placeholder = self._returnType
            else:
                self._returnType_placeholder = _parse_datatype_string(self._returnType)

        # GROUP_MAP UDFs return a pandas DataFrame, which maps to a StructType.
        if self.evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \
                and not isinstance(self._returnType_placeholder, StructType):
            raise ValueError("Invalid returnType: returnType must be a StructType for "
                             "pandas_udf with function type GROUP_MAP")

        return self._returnType_placeholder

    @property
    def _judf(self):
        """The JVM ``UserDefinedPythonFunction`` jobj, created on first access.

        It is possible that concurrent access, to a newly created UDF,
        will initialize multiple UserDefinedPythonFunctions.
        This is unlikely, doesn't affect correctness,
        and should have a minimal performance impact.
        """
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder

    def _create_judf(self):
        """Build the JVM ``UserDefinedPythonFunction`` from this UDF's state."""
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        """Apply the UDF to the given columns and return the result Column."""
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

    # This function is for improving the online help system in the interactive interpreter.
    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
    # argument annotation. (See: SPARK-19161)
    def _wrapped(self):
        """
        Wrap this udf with a function and attach docstring from func
        """

        # It is possible for a callable instance without __name__ attribute or/and
        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
        # we should avoid wrapping the attributes from the wrapped function to the wrapper
        # function. So, we take out these attribute names from the default names to set and
        # then manually assign it after being wrapped.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args):
            return self(*args)

        wrapper.__name__ = self._name
        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                              else self.func.__class__.__module__)

        # Expose the UDF's metadata on the wrapper so callers (e.g.
        # registerFunction) can introspect it without the UDF object itself.
        wrapper.func = self.func
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
        wrapper.deterministic = self.deterministic
        wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped()

        return wrapper

    def asNondeterministic(self):
        """
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 2.3
        """
        # Explicitly clear the cached JVM UDF (if any) so a new one is created
        # with the updated `deterministic` flag. Without this, calling
        # asNondeterministic() after the UDF has already been used (which
        # initializes _judf_placeholder lazily) would silently have no effect.
        self._judf_placeholder = None
        self.deterministic = False
        return self
|