From 44d762abc6395570f1f493a145fd5d1cbdf0b49e Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Tue, 18 May 2021 08:45:55 +0000 Subject: [PATCH] [SPARK-35389][SQL] V2 ScalarFunction should support magic method with null arguments ### What changes were proposed in this pull request? When creating `Invoke` and `StaticInvoke` for `ScalarFunction`'s magic method, set `propagateNull` to false. ### Why are the changes needed? When `propgagateNull` is true (which is the default value), `Invoke` and `StaticInvoke` will return null if any of the argument is null. For scalar function this is incorrect, as we should leave the logic to function implementation instead. ### Does this PR introduce _any_ user-facing change? Yes. Now null arguments shall be properly handled with magic method. ### How was this patch tested? Added new tests. Closes #32553 from sunchao/SPARK-35389. Authored-by: Chao Sun Signed-off-by: Wenchen Fan --- .../catalog/functions/ScalarFunction.java | 19 ++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 5 ++- .../expressions/objects/objects.scala | 26 +++++++++---- .../catalog/functions/JavaStrLen.java | 19 ++++++++++ .../connector/DataSourceV2FunctionSuite.scala | 37 ++++++++++++++++++- 5 files changed, 96 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index 858ab92349..d261a245a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.DataType; * InternalRow API for the {@link DataType SQL data type} returned by {@link #resultType()}. * The mapping between {@link DataType} and the corresponding JVM type is defined below. *

+ * <h2>Magic method</h2>
* IMPORTANT: the default implementation of {@link #produceResult} throws * {@link UnsupportedOperationException}. Users must choose to either override this method, or * implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes individual parameters @@ -82,6 +83,24 @@ import org.apache.spark.sql.types.DataType; * following the mapping defined below, and then checking if there is a matching method from all the * declared methods in the UDF class, using method name and the Java types. *

+ * <h2>Handling of nullable primitive arguments</h2>
+ * <p>
+ * The handling of null primitive arguments differs between the magic method approach and
+ * the {@link #produceResult} approach. With the former, whenever any of the method arguments meets
+ * all of the following conditions:
+ * <ol>
+ *   <li>the argument is of primitive type</li>
+ *   <li>the argument is nullable</li>
+ *   <li>the value of the argument is null</li>
+ * </ol>
+ * Spark will return null directly instead of calling the magic method. On the other hand, Spark
+ * will pass null primitive arguments to {@link #produceResult}, and it is the user's responsibility
+ * to handle them in the function implementation.
+ * <p>
+ * Because of this difference, if Spark users want to implement special handling of nulls for
+ * nullable primitive arguments, they should override the {@link #produceResult} method instead
+ * of using the magic method approach.
+ * <p>
+ * <h2>Spark data type to Java type mapping</h2>
+ * <p>
* The following are the mapping from {@link DataType SQL data type} to Java type which is used * by Spark to infer parameter types for the magic methods as well as return value type for * {@link #produceResult}: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9954ca08c2..3f2e93a735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2204,11 +2204,12 @@ class Analyzer(override val catalogManager: CatalogManager) findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { case Some(m) if Modifier.isStatic(m.getModifiers) => StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), - MAGIC_METHOD_NAME, arguments, returnNullable = scalarFunc.isResultNullable) + MAGIC_METHOD_NAME, arguments, propagateNull = false, + returnNullable = scalarFunc.isResultNullable) case Some(_) => val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass)) Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(), - arguments, returnNullable = scalarFunc.isResultNullable) + arguments, propagateNull = false, returnNullable = scalarFunc.isResultNullable) case _ => // TODO: handle functions defined in Scala too - in Scala, even if a // subclass do not override the default method in parent interface diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e871c307aa..c88f785632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -50,7 +50,10 @@ trait InvokeLike extends Expression with NonSQLExpression { def propagateNull: Boolean - protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable) + protected lazy val needNullCheck: Boolean = needNullCheckForIndex.contains(true) + protected lazy val needNullCheckForIndex: Array[Boolean] = + arguments.map(a => a.nullable && (propagateNull || + ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray protected lazy val evaluatedArgs: Array[Object] = new Array[Object](arguments.length) private lazy val boxingFn: Any => Any = ScalaReflection.typeBoxedJavaMapping @@ -89,7 +92,7 @@ trait InvokeLike extends Expression with NonSQLExpression { val reset = s"$resultIsNull = false;" val argCodes = arguments.zipWithIndex.map { case (e, i) => val expr = e.genCode(ctx) - val updateResultIsNull = if (e.nullable) { + val updateResultIsNull = if (needNullCheckForIndex(i)) { s"$resultIsNull = ${expr.isNull};" } else { "" @@ -131,11 +134,14 @@ trait InvokeLike extends Expression with NonSQLExpression { def invoke(obj: Any, method: Method, input: InternalRow): Any = { var i = 0 val len = arguments.length + var resultNull = false while (i < len) { - evaluatedArgs(i) = arguments(i).eval(input).asInstanceOf[Object] + val result = arguments(i).eval(input).asInstanceOf[Object] + evaluatedArgs(i) = result + resultNull = resultNull || (result == null && needNullCheckForIndex(i)) i += 1 } - if (needNullCheck && evaluatedArgs.contains(null)) { + if (needNullCheck && resultNull) { // return null if one of arguments is null null } else { @@ -226,7 +232,9 @@ object SerializerSupport { * 
@param functionName The name of the method to call. * @param arguments An optional list of expressions to pass as arguments to the function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead - * of calling the function. + * of calling the function. Also note: when this is false but any of the + * arguments is of primitive type and is null, null also will be returned + * without invoking the function. * @param returnNullable When false, indicating the invoked method will always return * non-null value. */ @@ -318,7 +326,9 @@ case class StaticInvoke( * @param arguments An optional list of expressions, whose evaluation will be passed to the * function. * @param propagateNull When true, and any of the arguments is null, null will be returned instead - * of calling the function. + * of calling the function. Also note: when this is false but any of the + * arguments is of primitive type and is null, null also will be returned + * without invoking the function. * @param returnNullable When false, indicating the invoked method will always return * non-null value. */ @@ -452,7 +462,9 @@ object NewInstance { * @param cls The class to construct. * @param arguments A list of expression to use as arguments to the constructor. * @param propagateNull When true, if any of the arguments is null, then null will be returned - * instead of trying to construct the object. + * instead of trying to construct the object. Also note: when this is false + * but any of the arguments is of primitive type and is null, null also will + * be returned without constructing the object. * @param dataType The type of object being constructed, as a Spark SQL datatype. This allows you * to manually specify the type when the object in question is a valid internal * representation (i.e. ArrayData) instead of an object. 
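As an illustration of the two code paths described in the ScalarFunction javadoc above, here is a minimal sketch of a UDF with both a primitive-typed magic method and a `produceResult` override. The class and names are hypothetical and not part of this patch. With this change, a null BIGINT argument makes Spark return null without calling `invoke`, while `produceResult` still receives the null and must handle it itself:

```java
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical example: doubles a single BIGINT argument.
public class LongDouble implements ScalarFunction<Long> {
  @Override
  public DataType[] inputTypes() { return new DataType[] { DataTypes.LongType }; }

  @Override
  public DataType resultType() { return DataTypes.LongType; }

  @Override
  public String name() { return "long_double"; }

  // Magic method, resolved by name ScalarFunction.MAGIC_METHOD_NAME ("invoke").
  // The parameter is a Java primitive, so a null argument can never be passed in;
  // Spark now evaluates the whole expression to null instead of calling this method.
  public long invoke(long x) {
    return x * 2;
  }

  // Fallback path. Null arguments DO reach this method, so it must handle them itself.
  @Override
  public Long produceResult(InternalRow input) {
    if (input.isNullAt(0)) {
      return null;
    }
    return input.getLong(0) * 2;
  }
}
```

To be callable from SQL, such a bound function still has to be returned by an `UnboundFunction` registered in a `FunctionCatalog`, as the `JavaStrLen` test helper below does.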
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java index 7cd010b936..1b1689668e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java @@ -120,5 +120,24 @@ public class JavaStrLen implements UnboundFunction { public static class JavaStrLenNoImpl extends JavaStrLenBase { } + + // a null-safe version which returns 0 for null arguments + public static class JavaStrLenMagicNullSafe extends JavaStrLenBase { + public int invoke(UTF8String str) { + if (str == null) { + return 0; + } + return str.toString().length(); + } + } + + public static class JavaStrLenStaticMagicNullSafe extends JavaStrLenBase { + public static int invoke(UTF8String str) { + if (str == null) { + return 0; + } + return str.toString().length(); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index bd4dfe4044..801aee5b03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.sql.connector import java.util import java.util.Collections -import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaStrLen} +import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen} +import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.internal.SQLConf @@ -213,6 +215,39 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { .getMessage.contains("neither implement magic method nor override 'produceResult'")) } + test("SPARK-35389: magic function should handle null arguments") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "strlen"), new JavaStrLen(new JavaStrLenMagicNullSafe)) + addFunction(Identifier.of(Array("ns"), "strlen2"), + new JavaStrLen(new JavaStrLenStaticMagicNullSafe)) + Seq("strlen", "strlen2").foreach { name => + checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as STRING))"), Row(0) :: Nil) + } + } + + test("SPARK-35389: magic function should handle null primitive arguments") { + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMagic(false))) + addFunction(Identifier.of(Array("ns"), "static_add"), + new JavaLongAdd(new JavaLongAddMagic(false))) + + Seq("add", "static_add").foreach { name => + Seq(true, false).foreach { codegenEnabled => + val codeGenFactoryMode = if (codegenEnabled) 
FALLBACK else NO_CODEGEN + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString, + SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) { + + checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), 42L)"), Row(null) :: Nil) + checkAnswer(sql(s"SELECT testcat.ns.$name(42L, CAST(NULL as BIGINT))"), Row(null) :: Nil) + checkAnswer(sql(s"SELECT testcat.ns.$name(42L, 58L)"), Row(100) :: Nil) + checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), CAST(NULL as BIGINT))"), + Row(null) :: Nil) + } + } + } + } + test("bad bound function (neither scalar nor aggregate)") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
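The `JavaLongAdd` / `JavaLongAddMagic` helpers referenced by the new test are not part of this diff. As a rough, hypothetical sketch of the static-magic-method shape that the `static_add` case exercises (names and details are illustrative, not the actual test source): both parameters are primitive longs, so even with `propagateNull = false` the generated `StaticInvoke` evaluates to null whenever either input is null, which is what the `Row(null)` expectations assert.

```java
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;

// Illustrative only: a two-argument BIGINT addition exposed through a static magic method.
public class StaticLongAddExample implements ScalarFunction<Long> {
  @Override
  public DataType[] inputTypes() {
    return new DataType[] { DataTypes.LongType, DataTypes.LongType };
  }

  @Override
  public DataType resultType() { return DataTypes.LongType; }

  @Override
  public String name() { return "static_long_add_example"; }

  // Static magic method: the analyzer binds it with StaticInvoke (propagateNull = false).
  // Both parameters are primitive, so Spark never calls this with a null input; the
  // expression evaluates to null instead, without invoking the method.
  public static long invoke(long left, long right) {
    return left + right;
  }
}
```

`produceResult` is intentionally left as the default here; it only serves as the fallback when no matching magic method is found.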