[SPARK-26323][SQL] Scala UDF should still check input types even if some inputs are of type Any

## What changes were proposed in this pull request?

For Scala UDFs, when checking input nullability, we already skip inputs of type `Any` and only check the inputs that provide nullability info.

We should do the same when checking input types: skip only the inputs typed `Any` (whose expected type is unknown) and still type-check the remaining inputs. Previously, a single `Any` input disabled the input type check for every argument of the UDF.
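
For example, with a UDF whose second parameter is typed `Any` (a runnable sketch mirroring the new tests; it assumes an active `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.udf
import spark.implicits._

// `y: Any` carries no type info, but `x: Long` does.
val f = udf((x: Long, y: Any) => x)

// Before this patch, the one untyped input disabled the type check for ALL
// inputs, so the int column `i` reached the UDF without a cast to bigint.
// With this patch, `i` is still implicitly cast; only `j` (bound to `y: Any`)
// skips the check.
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j").select(f($"i", $"j"))
df.collect() // values 1L and 2L
```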

## How was this patch tested?

New tests in `UDFSuite`, covering both `udf()` and `spark.udf.register`.

Closes #23275 from cloud-fan/udf.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Wenchen Fan authored on 2019-01-08 22:44:33 +08:00; committed by Hyukjin Kwon
commit 72a572ffd6 (parent 29a7d2da44)
7 changed files with 184 additions and 193 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

@@ -882,7 +882,18 @@ object TypeCoercion {
       case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
         val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
-          implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+          // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works.
+          // In the future we should create types like `AbstractArrayType`, so that Scala UDF can
+          // accept inputs of array type of arbitrary element type.
+          if (expected == AnyDataType) {
+            in
+          } else {
+            implicitCast(
+              in,
+              udfInputToCastType(in.dataType, expected.asInstanceOf[DataType])
+            ).getOrElse(in)
+          }
         }
         udf.withNewChildren(children)
     }
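
The per-input decision above reduces to: leave an input alone when its expected type is `AnyDataType`, otherwise try an implicit cast to the expected concrete type. A self-contained stand-in model (the `AnyType`/`ConcreteType` names are made up for illustration and are not Spark's `AbstractDataType` hierarchy):

```scala
sealed trait AbstractType
case object AnyType extends AbstractType                  // models AnyDataType
final case class ConcreteType(name: String) extends AbstractType

// Leave `Any` inputs untouched; otherwise inject a cast when the types differ.
def coerce(input: String, actual: String, expected: AbstractType): String =
  expected match {
    case AnyType                        => input                  // skip the type check
    case ConcreteType(t) if t == actual => input                  // already correct
    case ConcreteType(t)                => s"CAST($input AS $t)"  // implicit cast
  }

// coerce("i", "int", ConcreteType("bigint")) == "CAST(i AS bigint)"
// coerce("j", "string", AnyType)             == "j"
```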

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AbstractDataType, DataType}

 /**
  * User-defined function.
@@ -48,7 +48,7 @@ case class ScalaUDF(
     dataType: DataType,
     children: Seq[Expression],
     inputsNullSafe: Seq[Boolean],
-    inputTypes: Seq[DataType] = Nil,
+    inputTypes: Seq[AbstractDataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

@@ -96,7 +96,7 @@ private[sql] object TypeCollection {
 /**
  * An `AbstractDataType` that matches any concrete data types.
  */
-protected[sql] object AnyDataType extends AbstractDataType {
+protected[sql] object AnyDataType extends AbstractDataType with Serializable {

   // Note that since AnyDataType matches any concrete types, defaultConcreteType should never
   // be invoked.
sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

@@ -123,17 +123,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
      |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
      |  val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
      |  val inputSchemas: Seq[Option[ScalaReflection.Schema]] = $inputSchemas
+     |  val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+     |  val finalUdf = if (nullable) udf else udf.asNonNullable()
      |  def builder(e: Seq[Expression]) = if (e.length == $x) {
-     |    ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-     |      if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-     |      Some(name), nullable, udfDeterministic = true)
+     |    finalUdf.createScalaUDF(e)
      |  } else {
      |    throw new AnalysisException("Invalid number of arguments for function " + name +
      |      ". Expected: $x; Found: " + e.length)
      |  }
      |  functionRegistry.createOrReplaceTempFunction(name, builder)
-     |  val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-     |  if (nullable) udf else udf.asNonNullable()
+     |  finalUdf
      |}""".stripMargin)
 }
@@ -170,17 +169,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 0; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -191,17 +189,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 1) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 1; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -212,17 +209,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 2) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 2; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -233,17 +229,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 3) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 3; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -254,17 +249,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 4) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 4; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -275,17 +269,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 5) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 5; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -296,17 +289,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 6) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 6; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -317,17 +309,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 7) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 7; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -338,17 +329,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 8) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 8; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -359,17 +349,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 9) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 9; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -380,17 +369,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 10) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 10; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -401,17 +389,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 11) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 11; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -422,17 +409,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 12) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 12; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -443,17 +429,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 13) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 13; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -464,17 +449,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 14) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 14; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -485,17 +469,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 15) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 15; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -506,17 +489,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 16) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 16; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -527,17 +509,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 17) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 17; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -548,17 +529,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 18) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 18; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -569,17 +549,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 19) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 19; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -590,17 +569,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 20) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 20; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -611,17 +589,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 21) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 21; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   /**
@@ -632,17 +609,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Try(ScalaReflection.schemaFor[A1]).toOption :: Try(ScalaReflection.schemaFor[A2]).toOption :: Try(ScalaReflection.schemaFor[A3]).toOption :: Try(ScalaReflection.schemaFor[A4]).toOption :: Try(ScalaReflection.schemaFor[A5]).toOption :: Try(ScalaReflection.schemaFor[A6]).toOption :: Try(ScalaReflection.schemaFor[A7]).toOption :: Try(ScalaReflection.schemaFor[A8]).toOption :: Try(ScalaReflection.schemaFor[A9]).toOption :: Try(ScalaReflection.schemaFor[A10]).toOption :: Try(ScalaReflection.schemaFor[A11]).toOption :: Try(ScalaReflection.schemaFor[A12]).toOption :: Try(ScalaReflection.schemaFor[A13]).toOption :: Try(ScalaReflection.schemaFor[A14]).toOption :: Try(ScalaReflection.schemaFor[A15]).toOption :: Try(ScalaReflection.schemaFor[A16]).toOption :: Try(ScalaReflection.schemaFor[A17]).toOption :: Try(ScalaReflection.schemaFor[A18]).toOption :: Try(ScalaReflection.schemaFor[A19]).toOption :: Try(ScalaReflection.schemaFor[A20]).toOption :: Try(ScalaReflection.schemaFor[A21]).toOption :: Try(ScalaReflection.schemaFor[A22]).toOption :: Nil
+    val udf = SparkUserDefinedFunction(func, dataType, inputSchemas).withName(name)
+    val finalUdf = if (nullable) udf else udf.asNonNullable()
     def builder(e: Seq[Expression]) = if (e.length == 22) {
-      ScalaUDF(func, dataType, e, inputSchemas.map(_.map(_.nullable).getOrElse(true)),
-        if (inputSchemas.contains(None)) Nil else inputSchemas.map(_.get.dataType),
-        Some(name), nullable, udfDeterministic = true)
+      finalUdf.createScalaUDF(e)
    } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 22; Found: " + e.length)
     }
     functionRegistry.createOrReplaceTempFunction(name, builder)
-    val udf = SparkUserDefinedFunction.create(func, dataType, inputSchemas).withName(name)
-    if (nullable) udf else udf.asNonNullable()
+    finalUdf
   }

   //////////////////////////////////////////////////////////////////////////////////////////////
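
The builder registered for each arity now closes over a single pre-built `finalUdf` and delegates to `createScalaUDF`, so `spark.udf.register` and `udf()` share one code path. The observable effect mirrors the new UDFSuite test (a sketch; it assumes a table `t` with an int column `i` and a string column `j`):

```scala
spark.udf.register("f", (x: Long, y: Any) => x)
// `i` is still cast to bigint even though `j` is bound to the untyped `Any` parameter.
spark.sql("SELECT f(i, j) FROM t").collect()
```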

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

@@ -20,8 +20,8 @@ package org.apache.spark.sql.expressions
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.types.{AnyDataType, DataType}

 /**
  * A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -88,40 +88,47 @@ sealed abstract class UserDefinedFunction {
 private[sql] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
-    inputTypes: Option[Seq[DataType]],
-    nullableTypes: Option[Seq[Boolean]],
+    inputSchemas: Seq[Option[ScalaReflection.Schema]],
     name: Option[String] = None,
     nullable: Boolean = true,
    deterministic: Boolean = true) extends UserDefinedFunction {

   @scala.annotation.varargs
   override def apply(exprs: Column*): Column = {
-    // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
-    // and `nullableTypes` is always set.
-    if (inputTypes.isDefined) {
-      assert(inputTypes.get.length == nullableTypes.get.length)
-    }
-    val inputsNullSafe = nullableTypes.getOrElse {
-      ScalaReflection.getParameterTypeNullability(f)
-    }
-    Column(ScalaUDF(
-      f,
-      dataType,
-      exprs.map(_.expr),
-      inputsNullSafe,
-      inputTypes.getOrElse(Nil),
-      udfName = name,
-      nullable = nullable,
-      udfDeterministic = deterministic))
+    Column(createScalaUDF(exprs.map(_.expr)))
   }

-  override def withName(name: String): UserDefinedFunction = {
+  private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
+    // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type
+    // check and null check for them.
+    val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
+    val inputsNullSafe = if (inputSchemas.isEmpty) {
+      // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to
+      // do reflection of the lambda function object and see if its arguments are nullable or not.
+      // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark
+      // uses Scala 2.12 by default since 3.0.
+      ScalaReflection.getParameterTypeNullability(f)
+    } else {
+      inputSchemas.map(_.map(_.nullable).getOrElse(true))
+    }
+    ScalaUDF(
+      f,
+      dataType,
+      exprs,
+      inputsNullSafe,
+      inputTypes,
+      udfName = name,
+      nullable = nullable,
+      udfDeterministic = deterministic)
+  }
+
+  override def withName(name: String): SparkUserDefinedFunction = {
     copy(name = Option(name))
   }

-  override def asNonNullable(): UserDefinedFunction = {
+  override def asNonNullable(): SparkUserDefinedFunction = {
     if (!nullable) {
       this
     } else {
@@ -129,7 +136,7 @@ private[sql] case class SparkUserDefinedFunction(
     }
   }

-  override def asNondeterministic(): UserDefinedFunction = {
+  override def asNondeterministic(): SparkUserDefinedFunction = {
     if (!deterministic) {
       this
     } else {
@@ -137,19 +144,3 @@ private[sql] case class SparkUserDefinedFunction(
     }
   }
 }
-
-private[sql] object SparkUserDefinedFunction {
-
-  def create(
-      f: AnyRef,
-      dataType: DataType,
-      inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
-    val inputTypes = if (inputSchemas.contains(None)) {
-      None
-    } else {
-      Some(inputSchemas.map(_.get.dataType))
-    }
-    val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
-    SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
-  }
-}
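
`createScalaUDF` derives both vectors from the per-input `Option[Schema]`, using `AnyDataType` as the "no expectation" placeholder, instead of collapsing to `Nil` as soon as any schema is missing. A stand-in model of the derivation (the `Schema` case class below stands in for `ScalaReflection.Schema`):

```scala
final case class Schema(dataType: String, nullable: Boolean)

// One entry per UDF parameter; None means "no type info" (e.g. a parameter typed `Any`).
val inputSchemas: Seq[Option[Schema]] =
  Seq(Some(Schema("bigint", nullable = false)), None)   // (x: Long, y: Any)

val inputTypes     = inputSchemas.map(_.map(_.dataType).getOrElse("AnyDataType"))
val inputsNullSafe = inputSchemas.map(_.map(_.nullable).getOrElse(true))
// inputTypes     == Seq("bigint", "AnyDataType")
// inputsNullSafe == Seq(false, true)
```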

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

@@ -3874,7 +3874,7 @@ object functions {
      |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
      |  val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
      |  val inputSchemas = $inputSchemas
-     |  val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+     |  val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
      |  if (nullable) udf else udf.asNonNullable()
      |}""".stripMargin)
 }
@@ -3897,7 +3897,7 @@ object functions {
      | */
      |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
      |  val func = f$anyCast.call($anyParams)
-     |  SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
+     |  SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
      |}""".stripMargin)
 }
@@ -3919,7 +3919,7 @@ object functions {
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3935,7 +3935,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3951,7 +3951,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3967,7 +3967,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3983,7 +3983,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3999,7 +3999,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4015,7 +4015,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4031,7 +4031,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4047,7 +4047,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4063,7 +4063,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4079,7 +4079,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
-    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4098,7 +4098,7 @@ object functions {
    */
   def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+    SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
   }

   /**
@@ -4112,7 +4112,7 @@ object functions {
    */
   def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None))
   }

   /**
@@ -4126,7 +4126,7 @@ object functions {
    */
   def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None))
   }

   /**
@@ -4140,7 +4140,7 @@ object functions {
    */
   def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None))
   }

   /**
@@ -4154,7 +4154,7 @@ object functions {
    */
   def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None))
   }

   /**
@@ -4168,7 +4168,7 @@ object functions {
    */
   def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None))
   }

   /**
@@ -4182,7 +4182,7 @@ object functions {
    */
   def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None))
   }

   /**
@@ -4196,7 +4196,7 @@ object functions {
    */
   def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None))
   }

   /**
@@ -4210,7 +4210,7 @@ object functions {
    */
   def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None))
   }

   /**
@@ -4224,7 +4224,7 @@ object functions {
    */
   def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None))
   }

   /**
@@ -4238,7 +4238,7 @@ object functions {
    */
   def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None))
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None))
   }

 // scalastyle:on parameter.number
@@ -4257,9 +4257,7 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
-    // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
-    // unavailable. We may need to create type-safe overloaded versions of udf() methods.
-    SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None)
+    SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
   }

   /**
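
The untyped `udf(f: AnyRef, dataType: DataType)` entry point now passes `inputSchemas = Nil`, meaning every input is treated as `Any`: no casts are added, and nullability falls back to reflection on the lambda (see `createScalaUDF` above). A sketch of what that looks like at the call site:

```scala
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.LongType

// No input schemas: the analyzer adds no casts, so the function body must
// handle whatever runtime values the child expressions produce.
val g = udf((x: Any) => x.asInstanceOf[Number].longValue + 1, LongType)
```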

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

@@ -450,4 +450,19 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     })
     checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
   }
+
+  test("SPARK-26323 Verify input type check - with udf()") {
+    val f = udf((x: Long, y: Any) => x)
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j").select(f($"i", $"j"))
+    checkAnswer(df, Seq(Row(1L), Row(2L)))
+  }
+
+  test("SPARK-26323 Verify input type check - with udf.register") {
+    withTable("t") {
+      Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.format("json").saveAsTable("t")
+      spark.udf.register("f", (x: Long, y: Any) => x)
+      val df = spark.sql("SELECT f(i, j) FROM t")
+      checkAnswer(df, Seq(Row(1L), Row(2L)))
+    }
+  }
 }