[SPARK-3230][SQL] Fix udfs that return structs

We need to convert the case classes into Rows.

Author: Michael Armbrust <michael@databricks.com>

Closes #2133 from marmbrus/structUdfs and squashes the following commits:

189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
This commit is contained in:
Michael Armbrust 2014-08-28 00:15:23 -07:00
parent 68f75dcdfe
commit 76e3ba4264
4 changed files with 30 additions and 12 deletions

View file

@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@ -32,6 +31,15 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
/** Converts Scala objects to catalyst rows / types */
def convertToCatalyst(a: Any): Any = a match {
case o: Option[_] => o.orNull
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case other => other
}
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner
@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
def nullable = true
override def toString = s"scalaUDF(${children.mkString(",")})"
/** This method has been generated by this script
(1 to 22).map { x =>
@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
// scalastyle:off
override def eval(input: Row): Any = {
children.size match {
val result = children.size match {
case 0 => function.asInstanceOf[() => Any]()
case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
children(21).eval(input))
}
// scalastyle:on
ScalaReflection.convertToCatalyst(result)
}
}

View file

@ -204,14 +204,6 @@ case class Sort(
*/
@DeveloperApi
object ExistingRdd {
def convertToCatalyst(a: Any): Any = a match {
case o: Option[_] => o.orNull
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
case other => other
}
def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
data.mapPartitions { iterator =>
if (iterator.isEmpty) {
@ -223,7 +215,7 @@ object ExistingRdd {
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
mutableRow(i) = convertToCatalyst(r.productElement(i))
mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
i += 1
}
@ -245,6 +237,7 @@ object ExistingRdd {
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
/**
* :: DeveloperApi ::
* Computes the set of distinct input rows using a HashSet.

View file

@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
test("Simple UDF") {
@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
registerFunction("strLenScala", (_: String).length + (_:Int))
assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
}
test("struct UDF") {
registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result=
sql("SELECT returnStruct('test', 'test2') as ret")
.select("ret.f1".attr).first().getString(0)
assert(result == "test")
}
}