From 0bcbafb4b868234d695b553f215e21fe66e7ff77 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 11 Jan 2021 11:31:35 -0800 Subject: [PATCH] [SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF ### What changes were proposed in this pull request? This patch fixes few issues when using encoders to serialize input/output in `ScalaUDF`. ### Why are the changes needed? This fixes a bug when using encoders in Scala UDF. First, the output data type should be corrected to the corresponding data type of the object serializer. Second, `catalystConverter` should not serialize `Option[_]` as the ordinary row because in `ScalaUDF` case it is serialized to a column, not the top-level row. Otherwise, there will be a redundant `value` struct wrapping the serialized `Option[_]` object. ### Does this PR introduce _any_ user-facing change? Yes, fixing a bug of `ScalaUDF`. ### How was this patch tested? Unit test. Closes #31103 from viirya/SPARK-34002. Authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun --- .../catalyst/encoders/ExpressionEncoder.scala | 6 -- .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 64 ++++++++++++------- .../org/apache/spark/sql/functions.scala | 24 +++---- .../org/apache/spark/sql/DatasetSuite.scala | 27 ++++++++ 5 files changed, 80 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 80a0374ae1..665acdbb5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer} import org.apache.spark.sql.catalyst.expressions._ @@ -304,11 +303,6 @@ case class ExpressionEncoder[T]( StructField(s.name, s.dataType, s.nullable) }) - def dataTypeAndNullable: Schema = { - val dataType = if (isSerializedAsStruct) schema else schema.head.dataType - Schema(dataType, objSerializer.nullable) - } - /** * Returns true if the type `T` is serialized as a struct by `objSerializer`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 4a89d24e5f..8d2558579a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -120,7 +120,7 @@ case class ScalaUDF( */ private def catalystConverter: Any => Any = outputEncoder.map { enc => val toRow = enc.createSerializer().asInstanceOf[Any => Any] - if (enc.isSerializedAsStruct) { + if (enc.isSerializedAsStructForTopLevel) { value: Any => if (value == null) null else toRow(value).asInstanceOf[InternalRow] } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 237cfe18ed..7567ed6335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -49,6 +49,8 @@ import org.apache.spark.util.Utils @Stable class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging { + import UDFRegistration._ + protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = { log.debug( s""" @@ -135,7 +137,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends | */ |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) | val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders | val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) | val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -183,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -204,7 +206,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -225,7 +227,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -246,7 +248,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -267,7 +269,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -288,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -309,7 +311,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -330,7 +332,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -351,7 +353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -372,7 +374,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -393,7 +395,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -414,7 +416,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -435,7 +437,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -456,7 +458,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -477,7 +479,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -498,7 +500,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -519,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -540,7 +542,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -561,7 +563,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -582,7 +584,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -603,7 +605,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -624,7 +626,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -645,7 +647,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name) val finalUdf = if (nullable) udf else udf.asNonNullable() @@ -1121,3 +1123,17 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends // scalastyle:on line.size.limit } + +private[sql] object UDFRegistration { + /** + * Obtaining the schema of output encoder for `ScalaUDF`. + * + * As the serialization in `ScalaUDF` is for individual column, not the whole row, + * we just take the data type of vanilla object serializer, not `serializer` which + * is transformed somehow for top-level row. + */ + def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = { + ScalaReflection.Schema(outputEncoder.objSerializer.dataType, + outputEncoder.objSerializer.nullable) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 764e08862a..ff3f568f06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -4603,7 +4603,7 @@ object functions { | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + | val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) | val inputEncoders = $inputEncoders | val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) | if (nullable) udf else udf.asNonNullable() @@ -4710,7 +4710,7 @@ object functions { */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4727,7 +4727,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4744,7 +4744,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4761,7 +4761,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4778,7 +4778,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4795,7 +4795,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4812,7 +4812,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4829,7 +4829,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4846,7 +4846,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4863,7 +4863,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() @@ -4880,7 +4880,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val outputEncoder = Try(ExpressionEncoder[RT]()).toOption - val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT]) + val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT]) val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder) if (nullable) udf else udf.asNonNullable() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3a169e4878..2ec4c6918a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1982,8 +1982,35 @@ class DatasetSuite extends QueryTest assert(timezone == "Asia/Shanghai") } } + + test("SPARK-34002: Fix broken Option input/output in UDF") { + def f1(bar: Bar): Option[Bar] = { + None + } + + def f2(bar: Option[Bar]): Option[Bar] = { + bar + } + + val udf1 = udf(f1 _).withName("f1") + val udf2 = udf(f2 _).withName("f2") + + val df = (1 to 2).map(i => Tuple1(Bar(1))).toDF("c0") + val withUDF = df + .withColumn("c1", udf1(col("c0"))) + .withColumn("c2", udf2(col("c1"))) + + assert(withUDF.schema == StructType( + StructField("c0", StructType(StructField("a", IntegerType, false) :: Nil)) :: + StructField("c1", StructType(StructField("a", IntegerType, false) :: Nil)) :: + StructField("c2", StructType(StructField("a", IntegerType, false) :: Nil)) :: Nil)) + + checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil) + } } +case class Bar(a: Int) + object AssertExecutionId { def apply(id: Long): Long = { assert(TaskContext.get().getLocalProperty(SQLExecution.EXECUTION_ID_KEY) != null)