diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8282f7ea62..ec0e44b7f2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -112,6 +112,53 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 11c7950c06..a04af7f1dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,13 +30,18 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) + inputTypes: Seq[DataType] = Nil, + isDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true override def toString: String = s"UDF(${children.mkString(",")})" + override def foldable: Boolean = deterministic && children.forall(_.foldable) + + override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic) + // scalastyle:off /** This method has been generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fc4d0938c5..f5b95e13e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -58,8 +58,10 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. - * @param udaf the UDAF needs to be registered. + * @param udaf the UDAF that needs to be registered. * @return the registered UDAF. + * + * @since 1.5.0 */ def register( name: String, @@ -69,6 +71,22 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } + /** + * Register a user-defined function (UDF). + * + * @param name the name of the UDF. + * @param udf the UDF that needs to be registered. + * @return the registered UDF. + * + * @since 1.6.0 + */ + def register( + name: String, + udf: UserDefinedFunction): UserDefinedFunction = { + functionRegistry.registerFunction(name, udf.builder) + udf + } + // scalastyle:off /* register 0-22 were generated by this script @@ -86,9 +104,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try($inputTypes).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf }""") } @@ -118,9 +136,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -131,9 +149,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -144,9 +162,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -157,9 +175,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -170,9 +188,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -183,9 +201,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -196,9 +214,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -209,9 +227,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -222,9 +240,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -235,9 +253,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -248,9 +266,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -261,9 +279,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -274,9 +292,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -287,9 +305,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -300,9 +318,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -313,9 +331,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -326,9 +344,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -339,9 +357,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -352,9 +370,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -365,9 +383,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -378,9 +396,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -391,9 +409,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } /** @@ -404,9 +422,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) - functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType) + val udf = UserDefinedFunction(func, dataType, inputTypes) + functionRegistry.registerFunction(name, udf.builder) + udf } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 0f8cd280b5..1319391db5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,11 +44,20 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil) { + inputTypes: Seq[DataType] = Nil, + deterministic: Boolean = true) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic)) } + + protected[sql] def builder: Seq[Expression] => ScalaUDF = { + (exprs: Seq[Expression]) => + ScalaUDF(f, dataType, exprs, inputTypes, deterministic) + } + + def nondeterministic: UserDefinedFunction = + UserDefinedFunction(f, dataType, inputTypes, deterministic = false) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0435a0dba..6e510f0b8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } + + private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { + val udfs = df.queryExecution.optimizedPlan.collect { + case p: logical.Project => p.projectList.flatMap { + case e => e.collect { + case udf: ScalaUDF => udf + } + } + }.flatten + assert(udfs.length === expectedNumUDFs) + } + + test("foldable udf") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sql("SELECT 1 as a") + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 0) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using UDFRegistration") { + import org.apache.spark.sql.functions._ + + val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) + sqlContext.udf.register("plusOne2", myUDF.nondeterministic) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne1", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne2", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using udf function") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + // nondeterministicUDF will not be foldable. + val df = sql("SELECT 1 as a") + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("override a registered udf") { + sqlContext.udf.register("intExpected", (x: Int) => x) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + + sqlContext.udf.register("intExpected", (x: Int) => x + 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7274479989..f14b2886a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration)