[SPARK-35389][SQL] V2 ScalarFunction should support magic method with null arguments
### What changes were proposed in this pull request? When creating `Invoke` and `StaticInvoke` for `ScalarFunction`'s magic method, set `propagateNull` to false. ### Why are the changes needed? When `propgagateNull` is true (which is the default value), `Invoke` and `StaticInvoke` will return null if any of the argument is null. For scalar function this is incorrect, as we should leave the logic to function implementation instead. ### Does this PR introduce _any_ user-facing change? Yes. Now null arguments shall be properly handled with magic method. ### How was this patch tested? Added new tests. Closes #32553 from sunchao/SPARK-35389. Authored-by: Chao Sun <sunchao@apple.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
cce0048c78
commit
44d762abc6
|
@ -31,6 +31,7 @@ import org.apache.spark.sql.types.DataType;
|
|||
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
|
||||
* The mapping between {@link DataType} and the corresponding JVM type is defined below.
|
||||
* <p>
|
||||
* <h2> Magic method </h2>
|
||||
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
|
||||
* {@link UnsupportedOperationException}. Users must choose to either override this method, or
|
||||
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
|
||||
|
@ -82,6 +83,24 @@ import org.apache.spark.sql.types.DataType;
|
|||
* following the mapping defined below, and then checking if there is a matching method from all the
|
||||
* declared methods in the UDF class, using method name and the Java types.
|
||||
* <p>
|
||||
* <h2> Handling of nullable primitive arguments </h2>
|
||||
* The handling of null primitive arguments is different between the magic method approach and
|
||||
* the {@link #produceResult} approach. With the former, whenever any of the method arguments meet
|
||||
* the following conditions:
|
||||
* <ol>
|
||||
* <li>the argument is of primitive type</li>
|
||||
* <li>the argument is nullable</li>
|
||||
* <li>the value of the argument is null</li>
|
||||
* </ol>
|
||||
* Spark will return null directly instead of calling the magic method. On the other hand, Spark
|
||||
* will pass null primitive arguments to {@link #produceResult} and it is user's responsibility to
|
||||
* handle them in the function implementation.
|
||||
* <p>
|
||||
* Because of the difference, if Spark users want to implement special handling of nulls for
|
||||
* nullable primitive arguments, they should override the {@link #produceResult} method instead
|
||||
* of using the magic method approach.
|
||||
* <p>
|
||||
* <h2> Spark data type to Java type mapping </h2>
|
||||
* The following are the mapping from {@link DataType SQL data type} to Java type which is used
|
||||
* by Spark to infer parameter types for the magic methods as well as return value type for
|
||||
* {@link #produceResult}:
|
||||
|
|
|
@ -2204,11 +2204,12 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
|
||||
case Some(m) if Modifier.isStatic(m.getModifiers) =>
|
||||
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
|
||||
MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
|
||||
MAGIC_METHOD_NAME, arguments, propagateNull = false,
|
||||
returnNullable = scalarFunc.isResultNullable)
|
||||
case Some(_) =>
|
||||
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
|
||||
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
|
||||
arguments, returnNullable = scalarFunc.isResultNullable)
|
||||
arguments, propagateNull = false, returnNullable = scalarFunc.isResultNullable)
|
||||
case _ =>
|
||||
// TODO: handle functions defined in Scala too - in Scala, even if a
|
||||
// subclass do not override the default method in parent interface
|
||||
|
|
|
@ -50,7 +50,10 @@ trait InvokeLike extends Expression with NonSQLExpression {
|
|||
|
||||
def propagateNull: Boolean
|
||||
|
||||
protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
|
||||
protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true)
|
||||
protected lazy val needNullCheckForIndex: Array[Boolean] =
|
||||
arguments.map(a => a.nullable && (propagateNull ||
|
||||
ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
|
||||
protected lazy val evaluatedArgs: Array[Object] = new Array[Object](arguments.length)
|
||||
private lazy val boxingFn: Any => Any =
|
||||
ScalaReflection.typeBoxedJavaMapping
|
||||
|
@ -89,7 +92,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
|
|||
val reset = s"$resultIsNull = false;"
|
||||
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
|
||||
val expr = e.genCode(ctx)
|
||||
val updateResultIsNull = if (e.nullable) {
|
||||
val updateResultIsNull = if (needNullCheckForIndex(i)) {
|
||||
s"$resultIsNull = ${expr.isNull};"
|
||||
} else {
|
||||
""
|
||||
|
@ -131,11 +134,14 @@ trait InvokeLike extends Expression with NonSQLExpression {
|
|||
def invoke(obj: Any, method: Method, input: InternalRow): Any = {
|
||||
var i = 0
|
||||
val len = arguments.length
|
||||
var resultNull = false
|
||||
while (i < len) {
|
||||
evaluatedArgs(i) = arguments(i).eval(input).asInstanceOf[Object]
|
||||
val result = arguments(i).eval(input).asInstanceOf[Object]
|
||||
evaluatedArgs(i) = result
|
||||
resultNull = resultNull || (result == null && needNullCheckForIndex(i))
|
||||
i += 1
|
||||
}
|
||||
if (needNullCheck && evaluatedArgs.contains(null)) {
|
||||
if (needNullCheck && resultNull) {
|
||||
// return null if one of arguments is null
|
||||
null
|
||||
} else {
|
||||
|
@ -226,7 +232,9 @@ object SerializerSupport {
|
|||
* @param functionName The name of the method to call.
|
||||
* @param arguments An optional list of expressions to pass as arguments to the function.
|
||||
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
|
||||
* of calling the function.
|
||||
* of calling the function. Also note: when this is false but any of the
|
||||
* arguments is of primitive type and is null, null also will be returned
|
||||
* without invoking the function.
|
||||
* @param returnNullable When false, indicating the invoked method will always return
|
||||
* non-null value.
|
||||
*/
|
||||
|
@ -318,7 +326,9 @@ case class StaticInvoke(
|
|||
* @param arguments An optional list of expressions, whose evaluation will be passed to the
|
||||
* function.
|
||||
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
|
||||
* of calling the function.
|
||||
* of calling the function. Also note: when this is false but any of the
|
||||
* arguments is of primitive type and is null, null also will be returned
|
||||
* without invoking the function.
|
||||
* @param returnNullable When false, indicating the invoked method will always return
|
||||
* non-null value.
|
||||
*/
|
||||
|
@ -452,7 +462,9 @@ object NewInstance {
|
|||
* @param cls The class to construct.
|
||||
* @param arguments A list of expression to use as arguments to the constructor.
|
||||
* @param propagateNull When true, if any of the arguments is null, then null will be returned
|
||||
* instead of trying to construct the object.
|
||||
* instead of trying to construct the object. Also note: when this is false
|
||||
* but any of the arguments is of primitive type and is null, null also will
|
||||
* be returned without constructing the object.
|
||||
* @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you
|
||||
* to manually specify the type when the object in question is a valid internal
|
||||
* representation (i.e. ArrayData) instead of an object.
|
||||
|
|
|
@ -120,5 +120,24 @@ public class JavaStrLen implements UnboundFunction {
|
|||
|
||||
public static class JavaStrLenNoImpl extends JavaStrLenBase {
|
||||
}
|
||||
|
||||
// a null-safe version which returns 0 for null arguments
|
||||
public static class JavaStrLenMagicNullSafe extends JavaStrLenBase {
|
||||
public int invoke(UTF8String str) {
|
||||
if (str == null) {
|
||||
return 0;
|
||||
}
|
||||
return str.toString().length();
|
||||
}
|
||||
}
|
||||
|
||||
public static class JavaStrLenStaticMagicNullSafe extends JavaStrLenBase {
|
||||
public static int invoke(UTF8String str) {
|
||||
if (str == null) {
|
||||
return 0;
|
||||
}
|
||||
return str.toString().length();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,12 +20,14 @@ package org.apache.spark.sql.connector
|
|||
import java.util
|
||||
import java.util.Collections
|
||||
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen}
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
|
||||
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
|
||||
|
||||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.sql.{AnalysisException, Row}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
|
||||
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
|
||||
import org.apache.spark.sql.connector.catalog.functions._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
@ -213,6 +215,39 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
|
|||
.getMessage.contains("neither implement magic method nor override 'produceResult'"))
|
||||
}
|
||||
|
||||
test("SPARK-35389: magic function should handle null arguments") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), new JavaStrLen(new JavaStrLenMagicNullSafe))
|
||||
addFunction(Identifier.of(Array("ns"), "strlen2"),
|
||||
new JavaStrLen(new JavaStrLenStaticMagicNullSafe))
|
||||
Seq("strlen", "strlen2").foreach { name =>
|
||||
checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as STRING))"), Row(0) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-35389: magic function should handle null primitive arguments") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMagic(false)))
|
||||
addFunction(Identifier.of(Array("ns"), "static_add"),
|
||||
new JavaLongAdd(new JavaLongAddMagic(false)))
|
||||
|
||||
Seq("add", "static_add").foreach { name =>
|
||||
Seq(true, false).foreach { codegenEnabled =>
|
||||
val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN
|
||||
|
||||
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString,
|
||||
SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) {
|
||||
|
||||
checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), 42L)"), Row(null) :: Nil)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.$name(42L, CAST(NULL as BIGINT))"), Row(null) :: Nil)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.$name(42L, 58L)"), Row(100) :: Nil)
|
||||
checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), CAST(NULL as BIGINT))"),
|
||||
Row(null) :: Nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("bad bound function (neither scalar nor aggregate)") {
|
||||
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
|
||||
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
|
||||
|
|
Loading…
Reference in a new issue