[SPARK-11469][SQL] Allow users to define nondeterministic udfs.

This is the first task (https://issues.apache.org/jira/browse/SPARK-11469) of https://issues.apache.org/jira/browse/SPARK-11438

Author: Yin Huai <yhuai@databricks.com>

Closes #9393 from yhuai/udfNondeterministic.
This commit is contained in:
Yin Huai 2015-11-02 21:18:38 -08:00
parent efaa4721b5
commit 9cf56c96b7
6 changed files with 262 additions and 78 deletions

View file

@ -112,6 +112,53 @@ object MimaExcludes {
"org.apache.spark.rdd.MapPartitionsWithPreparationRDD"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.rdd.MapPartitionsWithPreparationRDD$")
) ++ Seq(
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
)
case v if v.startsWith("1.5") =>
Seq(

View file

@ -30,13 +30,18 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil)
inputTypes: Seq[DataType] = Nil,
isDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with CodegenFallback {
override def nullable: Boolean = true
override def toString: String = s"UDF(${children.mkString(",")})"
override def foldable: Boolean = deterministic && children.forall(_.foldable)
override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic)
// scalastyle:off
/** This method has been generated by this script

View file

@ -58,8 +58,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
* Register a user-defined aggregate function (UDAF).
*
* @param name the name of the UDAF.
* @param udaf the UDAF needs to be registered.
* @param udaf the UDAF that needs to be registered.
* @return the registered UDAF.
*
* @since 1.5.0
*/
def register(
name: String,
@ -69,6 +71,22 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
udaf
}
/**
* Register a user-defined function (UDF).
*
* @param name the name of the UDF.
* @param udf the UDF that needs to be registered.
* @return the registered UDF.
*
* @since 1.6.0
*/
def register(
name: String,
udf: UserDefinedFunction): UserDefinedFunction = {
functionRegistry.registerFunction(name, udf.builder)
udf
}
// scalastyle:off
/* register 0-22 were generated by this script
@ -86,9 +104,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try($inputTypes).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}""")
}
@ -118,9 +136,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -131,9 +149,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -144,9 +162,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -157,9 +175,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -170,9 +188,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -183,9 +201,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -196,9 +214,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -209,9 +227,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -222,9 +240,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -235,9 +253,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -248,9 +266,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -261,9 +279,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -274,9 +292,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -287,9 +305,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -300,9 +318,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -313,9 +331,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -326,9 +344,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -339,9 +357,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -352,9 +370,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -365,9 +383,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -378,9 +396,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -391,9 +409,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
/**
@ -404,9 +422,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
val dataType = ScalaReflection.schemaFor[RT].dataType
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil)
def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
functionRegistry.registerFunction(name, builder)
UserDefinedFunction(func, dataType)
val udf = UserDefinedFunction(func, dataType, inputTypes)
functionRegistry.registerFunction(name, udf.builder)
udf
}
//////////////////////////////////////////////////////////////////////////////////////////////

View file

@ -44,11 +44,20 @@ import org.apache.spark.sql.types.DataType
case class UserDefinedFunction protected[sql] (
f: AnyRef,
dataType: DataType,
inputTypes: Seq[DataType] = Nil) {
inputTypes: Seq[DataType] = Nil,
deterministic: Boolean = true) {
def apply(exprs: Column*): Column = {
Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes))
Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic))
}
protected[sql] def builder: Seq[Expression] => ScalaUDF = {
(exprs: Seq[Expression]) =>
ScalaUDF(f, dataType, exprs, inputTypes, deterministic)
}
def nondeterministic: UserDefinedFunction =
UserDefinedFunction(f, dataType, inputTypes, deterministic = false)
}
/**

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext {
// pass a decimal to intExpected.
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = {
val udfs = df.queryExecution.optimizedPlan.collect {
case p: logical.Project => p.projectList.flatMap {
case e => e.collect {
case udf: ScalaUDF => udf
}
}
}.flatten
assert(udfs.length === expectedNumUDFs)
}
test("foldable udf") {
import org.apache.spark.sql.functions._
val myUDF = udf((x: Int) => x + 1)
{
val df = sql("SELECT 1 as a")
.select(col("a"), myUDF(col("a")).as("b"))
.select(col("a"), col("b"), myUDF(col("b")).as("c"))
checkNumUDFs(df, 0)
checkAnswer(df, Row(1, 2, 3))
}
}
test("nondeterministic udf: using UDFRegistration") {
import org.apache.spark.sql.functions._
val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1)
sqlContext.udf.register("plusOne2", myUDF.nondeterministic)
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), myUDF(col("a")).as("b"))
.select(col("a"), col("b"), myUDF(col("b")).as("c"))
checkNumUDFs(df, 3)
checkAnswer(df, Row(1, 2, 3))
}
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), callUDF("plusOne1", col("a")).as("b"))
.select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c"))
checkNumUDFs(df, 3)
checkAnswer(df, Row(1, 2, 3))
}
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
.select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
checkNumUDFs(df, 2)
checkAnswer(df, Row(1, 2, 3))
}
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), callUDF("plusOne2", col("a")).as("b"))
.select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c"))
checkNumUDFs(df, 2)
checkAnswer(df, Row(1, 2, 3))
}
}
test("nondeterministic udf: using udf function") {
import org.apache.spark.sql.functions._
val myUDF = udf((x: Int) => x + 1)
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), myUDF(col("a")).as("b"))
.select(col("a"), col("b"), myUDF(col("b")).as("c"))
checkNumUDFs(df, 3)
checkAnswer(df, Row(1, 2, 3))
}
{
val df = sqlContext.range(1, 2).select(col("id").as("a"))
.select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
.select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
checkNumUDFs(df, 2)
checkAnswer(df, Row(1, 2, 3))
}
{
// nondeterministicUDF will not be foldable.
val df = sql("SELECT 1 as a")
.select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
.select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
checkNumUDFs(df, 2)
checkAnswer(df, Row(1, 2, 3))
}
}
test("override a registered udf") {
sqlContext.udf.register("intExpected", (x: Int) => x)
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
sqlContext.udf.register("intExpected", (x: Int) => x + 1)
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2)
}
}

View file

@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)