[SPARK-35389][SQL] V2 ScalarFunction should support magic method with null arguments

### What changes were proposed in this pull request?

When creating `Invoke` and `StaticInvoke` for `ScalarFunction`'s magic method, set `propagateNull` to false.

### Why are the changes needed?

When `propagateNull` is true (which is the default value), `Invoke` and `StaticInvoke` will return null if any of the arguments is null. For scalar functions this is incorrect, as we should leave that logic to the function implementation instead.

### Does this PR introduce _any_ user-facing change?

Yes. Now null arguments shall be properly handled with magic method.

### How was this patch tested?

Added new tests.

Closes #32553 from sunchao/SPARK-35389.

Authored-by: Chao Sun <sunchao@apple.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Chao Sun 2021-05-18 08:45:55 +00:00 committed by Wenchen Fan
parent cce0048c78
commit 44d762abc6
5 changed files with 96 additions and 10 deletions

View file

@ -31,6 +31,7 @@ import org.apache.spark.sql.types.DataType;
* InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}.
* The mapping between {@link DataType} and the corresponding JVM type is defined below.
* <p>
* <h2> Magic method </h2>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult} throws
* {@link UnsupportedOperationException}. Users must choose to either override this method, or
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters
@ -82,6 +83,24 @@ import org.apache.spark.sql.types.DataType;
* following the mapping defined below, and then checking if there is a matching method from all the
* declared methods in the UDF class, using method name and the Java types.
* <p>
* <h2> Handling of nullable primitive arguments </h2>
* The handling of null primitive arguments is different between the magic method approach and
* the {@link #produceResult} approach. With the former, whenever any of the method arguments meet
* the following conditions:
* <ol>
* <li>the argument is of primitive type</li>
* <li>the argument is nullable</li>
* <li>the value of the argument is null</li>
* </ol>
* Spark will return null directly instead of calling the magic method. On the other hand, Spark
* will pass null primitive arguments to {@link #produceResult} and it is user's responsibility to
* handle them in the function implementation.
* <p>
* Because of the difference, if Spark users want to implement special handling of nulls for
* nullable primitive arguments, they should override the {@link #produceResult} method instead
* of using the magic method approach.
* <p>
* <h2> Spark data type to Java type mapping </h2>
* The following are the mapping from {@link DataType SQL data type} to Java type which is used
* by Spark to infer parameter types for the magic methods as well as return value type for
* {@link #produceResult}:

View file

@ -2204,11 +2204,12 @@ class Analyzer(override val catalogManager: CatalogManager)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable)
MAGIC_METHOD_NAME, arguments, propagateNull = false,
returnNullable = scalarFunc.isResultNullable)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, returnNullable = scalarFunc.isResultNullable)
arguments, propagateNull = false, returnNullable = scalarFunc.isResultNullable)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface

View file

@ -50,7 +50,10 @@ trait InvokeLike extends Expression with NonSQLExpression {
def propagateNull: Boolean
protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true)
protected lazy val needNullCheckForIndex: Array[Boolean] =
arguments.map(a => a.nullable && (propagateNull ||
ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
protected lazy val evaluatedArgs: Array[Object] = new Array[Object](arguments.length)
private lazy val boxingFn: Any => Any =
ScalaReflection.typeBoxedJavaMapping
@ -89,7 +92,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
val reset = s"$resultIsNull = false;"
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
val updateResultIsNull = if (e.nullable) {
val updateResultIsNull = if (needNullCheckForIndex(i)) {
s"$resultIsNull = ${expr.isNull};"
} else {
""
@ -131,11 +134,14 @@ trait InvokeLike extends Expression with NonSQLExpression {
def invoke(obj: Any, method: Method, input: InternalRow): Any = {
var i = 0
val len = arguments.length
var resultNull = false
while (i < len) {
evaluatedArgs(i) = arguments(i).eval(input).asInstanceOf[Object]
val result = arguments(i).eval(input).asInstanceOf[Object]
evaluatedArgs(i) = result
resultNull = resultNull || (result == null && needNullCheckForIndex(i))
i += 1
}
if (needNullCheck && evaluatedArgs.contains(null)) {
if (needNullCheck && resultNull) {
// return null if any null-checked argument evaluated to null
null
} else {
@ -226,7 +232,9 @@ object SerializerSupport {
* @param functionName The name of the method to call.
* @param arguments An optional list of expressions to pass as arguments to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
* of calling the function. Also note: when this is false but any of the
* arguments is of primitive type and is null, null also will be returned
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
*/
@ -318,7 +326,9 @@ case class StaticInvoke(
* @param arguments An optional list of expressions, whose evaluation will be passed to the
* function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
* of calling the function. Also note: when this is false but any of the
* arguments is of primitive type and is null, null also will be returned
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
*/
@ -452,7 +462,9 @@ object NewInstance {
* @param cls The class to construct.
* @param arguments A list of expression to use as arguments to the constructor.
* @param propagateNull When true, if any of the arguments is null, then null will be returned
* instead of trying to construct the object.
* instead of trying to construct the object. Also note: when this is false
* but any of the arguments is of primitive type and is null, null also will
* be returned without constructing the object.
* @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you
* to manually specify the type when the object in question is a valid internal
* representation (i.e. ArrayData) instead of an object.

View file

@ -120,5 +120,24 @@ public class JavaStrLen implements UnboundFunction {
public static class JavaStrLenNoImpl extends JavaStrLenBase {
}
// a null-safe version which returns 0 for null arguments
/**
 * Null-safe magic-method variant: maps a null input to length 0 itself instead of
 * relying on Spark's null propagation (which, per SPARK-35389, no longer short-circuits
 * magic-method calls for nullable non-primitive arguments).
 */
public static class JavaStrLenMagicNullSafe extends JavaStrLenBase {
  public int invoke(UTF8String str) {
    // Explicit null handling inside the function implementation.
    return str == null ? 0 : str.toString().length();
  }
}
/**
 * Static counterpart of the null-safe magic method: exercises the {@code StaticInvoke}
 * code path while handling a null argument inside the implementation (returns 0).
 */
public static class JavaStrLenStaticMagicNullSafe extends JavaStrLenBase {
  public static int invoke(UTF8String str) {
    // Explicit null handling inside the (static) function implementation.
    return str == null ? 0 : str.toString().length();
  }
}
}

View file

@ -20,12 +20,14 @@ package org.apache.spark.sql.connector
import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen}
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.internal.SQLConf
@ -213,6 +215,39 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
.getMessage.contains("neither implement magic method nor override 'produceResult'"))
}
test("SPARK-35389: magic function should handle null arguments") {
  // Both the instance magic method and the static magic method must receive the
  // null argument (and return 0) rather than have Spark short-circuit to null.
  catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
  val nullSafeImpls = Seq(
    "strlen" -> new JavaStrLen(new JavaStrLenMagicNullSafe),
    "strlen2" -> new JavaStrLen(new JavaStrLenStaticMagicNullSafe))
  nullSafeImpls.foreach { case (fnName, fn) =>
    addFunction(Identifier.of(Array("ns"), fnName), fn)
    checkAnswer(sql(s"SELECT testcat.ns.$fnName(CAST(NULL as STRING))"), Row(0) :: Nil)
  }
}
// For primitive-typed parameters a null argument cannot be passed to the magic
// method, so Spark itself must return null without invoking it — verified here
// for both the interpreted and the codegen execution paths.
test("SPARK-35389: magic function should handle null primitive arguments") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMagic(false)))
// NOTE(review): "static_add" is registered with the same JavaLongAddMagic as "add";
// the name suggests a static-magic variant was intended — confirm against
// JavaLongAdd's nested classes (only JavaLongAddMagic is imported in this file).
addFunction(Identifier.of(Array("ns"), "static_add"),
new JavaLongAdd(new JavaLongAddMagic(false)))
Seq("add", "static_add").foreach { name =>
// Run each query with whole-stage codegen enabled (FALLBACK) and fully
// interpreted (NO_CODEGEN) to cover both InvokeLike evaluation paths.
Seq(true, false).foreach { codegenEnabled =>
val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString,
SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) {
// A null in either (or both) primitive BIGINT slots yields null without
// calling the function; non-null arguments are actually added.
checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), 42L)"), Row(null) :: Nil)
checkAnswer(sql(s"SELECT testcat.ns.$name(42L, CAST(NULL as BIGINT))"), Row(null) :: Nil)
checkAnswer(sql(s"SELECT testcat.ns.$name(42L, 58L)"), Row(100) :: Nil)
checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), CAST(NULL as BIGINT))"),
Row(null) :: Nil)
}
}
}
}
test("bad bound function (neither scalar nor aggregate)") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))