[SPARK-34002][SQL] Fix the usage of encoder in ScalaUDF

### What changes were proposed in this pull request?

This patch fixes few issues when using encoders to serialize input/output in `ScalaUDF`.

### Why are the changes needed?

This fixes a bug when using encoders in Scala UDF. First, the output data type should be corrected to the corresponding data type of the object serializer. Second, `catalystConverter` should not serialize `Option[_]` as the ordinary row because in `ScalaUDF` case it is serialized to a column, not the top-level row. Otherwise, there will be a redundant `value` struct wrapping the serialized `Option[_]` object.

### Does this PR introduce _any_ user-facing change?

Yes, fixing a bug of `ScalaUDF`.

### How was this patch tested?

Unit test.

Closes #31103 from viirya/SPARK-34002.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Liang-Chi Hsieh 2021-01-11 11:31:35 -08:00 committed by Dongjoon Hyun
parent 1495ad8c46
commit 0bcbafb4b8
5 changed files with 80 additions and 43 deletions

View file

@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
@ -304,11 +303,6 @@ case class ExpressionEncoder[T](
StructField(s.name, s.dataType, s.nullable)
})
def dataTypeAndNullable: Schema = {
val dataType = if (isSerializedAsStruct) schema else schema.head.dataType
Schema(dataType, objSerializer.nullable)
}
/**
* Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/

View file

@ -120,7 +120,7 @@ case class ScalaUDF(
*/
private def catalystConverter: Any => Any = outputEncoder.map { enc =>
val toRow = enc.createSerializer().asInstanceOf[Any => Any]
if (enc.isSerializedAsStruct) {
if (enc.isSerializedAsStructForTopLevel) {
value: Any =>
if (value == null) null else toRow(value).asInstanceOf[InternalRow]
} else {

View file

@ -49,6 +49,8 @@ import org.apache.spark.util.Utils
@Stable
class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends Logging {
import UDFRegistration._
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
log.debug(
s"""
@ -135,7 +137,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| */
|def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
| val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
| val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
| val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
| val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = $inputEncoders
| val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
| val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -183,7 +185,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -204,7 +206,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -225,7 +227,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -246,7 +248,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -267,7 +269,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -288,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -309,7 +311,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -330,7 +332,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -351,7 +353,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -372,7 +374,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -393,7 +395,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -414,7 +416,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -435,7 +437,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -456,7 +458,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -477,7 +479,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -498,7 +500,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -519,7 +521,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -540,7 +542,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -561,7 +563,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -582,7 +584,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -603,7 +605,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -624,7 +626,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -645,7 +647,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Try(ExpressionEncoder[A11]()).toOption :: Try(ExpressionEncoder[A12]()).toOption :: Try(ExpressionEncoder[A13]()).toOption :: Try(ExpressionEncoder[A14]()).toOption :: Try(ExpressionEncoder[A15]()).toOption :: Try(ExpressionEncoder[A16]()).toOption :: Try(ExpressionEncoder[A17]()).toOption :: Try(ExpressionEncoder[A18]()).toOption :: Try(ExpressionEncoder[A19]()).toOption :: Try(ExpressionEncoder[A20]()).toOption :: Try(ExpressionEncoder[A21]()).toOption :: Try(ExpressionEncoder[A22]()).toOption :: Nil
val udf = SparkUserDefinedFunction(func, dataType, inputEncoders, outputEncoder).withName(name)
val finalUdf = if (nullable) udf else udf.asNonNullable()
@ -1121,3 +1123,17 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
// scalastyle:on line.size.limit
}
private[sql] object UDFRegistration {
/**
* Obtaining the schema of output encoder for `ScalaUDF`.
*
* As the serialization in `ScalaUDF` is for individual column, not the whole row,
* we just take the data type of vanilla object serializer, not `serializer` which
* is transformed somehow for top-level row.
*/
def outputSchema(outputEncoder: ExpressionEncoder[_]): ScalaReflection.Schema = {
ScalaReflection.Schema(outputEncoder.objSerializer.dataType,
outputEncoder.objSerializer.nullable)
}
}

View file

@ -4603,7 +4603,7 @@ object functions {
| */
|def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
| val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
| val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
| val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
| val inputEncoders = $inputEncoders
| val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
| if (nullable) udf else udf.asNonNullable()
@ -4710,7 +4710,7 @@ object functions {
*/
def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4727,7 +4727,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4744,7 +4744,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4761,7 +4761,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4778,7 +4778,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4795,7 +4795,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4812,7 +4812,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4829,7 +4829,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4846,7 +4846,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4863,7 +4863,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()
@ -4880,7 +4880,7 @@ object functions {
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val outputEncoder = Try(ExpressionEncoder[RT]()).toOption
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(_.dataTypeAndNullable).getOrElse(ScalaReflection.schemaFor[RT])
val ScalaReflection.Schema(dataType, nullable) = outputEncoder.map(UDFRegistration.outputSchema).getOrElse(ScalaReflection.schemaFor[RT])
val inputEncoders = Try(ExpressionEncoder[A1]()).toOption :: Try(ExpressionEncoder[A2]()).toOption :: Try(ExpressionEncoder[A3]()).toOption :: Try(ExpressionEncoder[A4]()).toOption :: Try(ExpressionEncoder[A5]()).toOption :: Try(ExpressionEncoder[A6]()).toOption :: Try(ExpressionEncoder[A7]()).toOption :: Try(ExpressionEncoder[A8]()).toOption :: Try(ExpressionEncoder[A9]()).toOption :: Try(ExpressionEncoder[A10]()).toOption :: Nil
val udf = SparkUserDefinedFunction(f, dataType, inputEncoders, outputEncoder)
if (nullable) udf else udf.asNonNullable()

View file

@ -1982,8 +1982,35 @@ class DatasetSuite extends QueryTest
assert(timezone == "Asia/Shanghai")
}
}
test("SPARK-34002: Fix broken Option input/output in UDF") {
def f1(bar: Bar): Option[Bar] = {
None
}
def f2(bar: Option[Bar]): Option[Bar] = {
bar
}
val udf1 = udf(f1 _).withName("f1")
val udf2 = udf(f2 _).withName("f2")
val df = (1 to 2).map(i => Tuple1(Bar(1))).toDF("c0")
val withUDF = df
.withColumn("c1", udf1(col("c0")))
.withColumn("c2", udf2(col("c1")))
assert(withUDF.schema == StructType(
StructField("c0", StructType(StructField("a", IntegerType, false) :: Nil)) ::
StructField("c1", StructType(StructField("a", IntegerType, false) :: Nil)) ::
StructField("c2", StructType(StructField("a", IntegerType, false) :: Nil)) :: Nil))
checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
}
}
case class Bar(a: Int)
object AssertExecutionId {
def apply(id: Long): Long = {
assert(TaskContext.get().getLocalProperty(SQLExecution.EXECUTION_ID_KEY) != null)