[SPARK-3230][SQL] Fix udfs that return structs
We need to convert the case classes into Rows.

Author: Michael Armbrust <michael@databricks.com>

Closes #2133 from marmbrus/structUdfs and squashes the following commits:

189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
parent 68f75dcdfe
commit 76e3ba4264
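At the API level, the bug is that registerFunction accepted a UDF returning a case class, but the raw object then leaked into the query engine, which expects catalyst Rows. A minimal usage sketch of the fixed behavior, assuming a Spark 1.1-era SQLContext with its implicits in scope (mirroring the UDFSuite test added below):

  case class FunctionResult(f1: String, f2: String)

  registerFunction("returnStruct", (a: String, b: String) => FunctionResult(a, b))

  // After this change the case class is converted into a GenericRow,
  // so struct fields resolve normally:
  sql("SELECT returnStruct('test', 'test2') as ret")
    .select("ret.f1".attr).first().getString(0) // "test"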
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst

 import java.sql.Timestamp

-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
@@ -32,6 +31,15 @@ object ScalaReflection

   case class Schema(dataType: DataType, nullable: Boolean)

+  /** Converts Scala objects to catalyst rows / types */
+  def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
+    case s: Seq[_] => s.map(convertToCatalyst)
+    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case other => other
+  }
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
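A standalone illustration of the recursion just added, runnable in plain Scala: a Vector stands in for GenericRow so no Spark classes are needed, and Person is a hypothetical example type.

  case class Person(name: String, nicknames: Seq[String], age: Option[Int])

  def convertToCatalyst(a: Any): Any = a match {
    case o: Option[_] => o.orNull
    case s: Seq[_] => s.map(convertToCatalyst)
    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
    case p: Product => p.productIterator.map(convertToCatalyst).toVector // GenericRow stand-in
    case other => other
  }

  // Options are unwrapped, Seqs converted element-wise, Products flattened:
  println(convertToCatalyst(Person("Ann", Seq("A"), None))) // Vector(Ann, List(A), null)

Note the ordering of the cases: Option and Seq are themselves Products, so they must be matched before the generic Product arm.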
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.util.ClosureCleaner
@@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

   def nullable = true

+  override def toString = s"scalaUDF(${children.mkString(",")})"
+
   /** This method has been generated by this script

     (1 to 22).map { x =>
@@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

   // scalastyle:off
   override def eval(input: Row): Any = {
-    children.size match {
+    val result = children.size match {
       case 0 => function.asInstanceOf[() => Any]()
       case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
@@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
           children(21).eval(input))
     }
     // scalastyle:on
+
+    ScalaReflection.convertToCatalyst(result)
   }
 }
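Taken together, the three ScalaUdf hunks change eval from returning the user function's raw result to converting it once at a single exit point. A toy model of that compute-then-convert shape, in self-contained Scala (applyUdf and the Vector row are illustrative stand-ins, not Spark classes):

  case class Pair(a: Int, b: Int) // hypothetical UDF return type

  def applyUdf(function: AnyRef, args: Seq[Any]): Any = {
    // compute first, like the generated eval arms for arities 0 to 22...
    val result = args.size match {
      case 1 => function.asInstanceOf[(Any) => Any](args(0))
      case 2 => function.asInstanceOf[(Any, Any) => Any](args(0), args(1))
    }
    // ...then convert once, so every arity is covered by one call site
    result match {
      case p: Product => p.productIterator.toVector // GenericRow stand-in
      case other => other
    }
  }

  println(applyUdf((a: Any, b: Any) => Pair(a.asInstanceOf[Int], b.asInstanceOf[Int]), Seq(1, 2)))
  // Vector(1, 2)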
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -204,14 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[_] => s.map(convertToCatalyst)
-    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
@@ -223,7 +215,7 @@ object ExistingRdd {
       bufferedIterator.map { r =>
         var i = 0
         while (i < mutableRow.length) {
-          mutableRow(i) = convertToCatalyst(r.productElement(i))
+          mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
           i += 1
         }
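The surviving loop in productToRowRdd reuses one mutable row per partition and now delegates conversion to the shared helper instead of a private copy. A self-contained model of that pattern (Record, toRows, and the Array buffer are illustrative stand-ins for Spark's MutableRow machinery):

  case class Record(name: String, score: Option[Int]) // hypothetical input type

  def toRows(iterator: Iterator[Product], arity: Int): Iterator[Vector[Any]] = {
    val mutableRow = new Array[Any](arity) // one buffer reused for every row
    iterator.map { r =>
      var i = 0
      while (i < mutableRow.length) {
        // the real code calls ScalaReflection.convertToCatalyst here
        mutableRow(i) = r.productElement(i) match {
          case o: Option[_] => o.orNull
          case other => other
        }
        i += 1
      }
      mutableRow.toVector // snapshot before the buffer is overwritten
    }
  }

  println(toRows(Iterator(Record("a", Some(1)), Record("b", None)), 2).toList)
  // List(Vector(a, 1), Vector(b, null))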
@@ -245,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._

 /* Implicits */
 import TestSQLContext._

+case class FunctionResult(f1: String, f2: String)
+
 class UDFSuite extends QueryTest {

   test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
     registerFunction("strLenScala", (_: String).length + (_: Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
+
+  test("struct UDF") {
+    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+    val result =
+      sql("SELECT returnStruct('test', 'test2') as ret")
+        .select("ret.f1".attr).first().getString(0)
+
+    assert(result == "test")
+  }
 }
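As a follow-on check (an assumption extrapolated from the test above, not an assertion present in this commit), the second struct field should resolve the same way through the converted row:

  sql("SELECT returnStruct('test', 'test2') as ret")
    .select("ret.f2".attr).first().getString(0) // expected: "test2"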