[SPARK-6734] [SQL] Add UDTF.close support in Generate
Some third-party UDTF extensions generate additional rows in the "GenericUDTF.close()" method, which is supported / documented by Hive.
https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
However, Spark SQL ignores "GenericUDTF.close()", which causes bugs when porting jobs from Hive to Spark SQL.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #5383 from chenghao-intel/udtf_close and squashes the following commits:
98b4e4b [Cheng Hao] Support UDTF.close
(cherry picked from commit 0da254fb29
)
Signed-off-by: Cheng Lian <lian@databricks.com>
This commit is contained in:
parent
d78f0e1b48
commit
42cf4a2a5e
|
@ -56,6 +56,12 @@ abstract class Generator extends Expression {
|
|||
|
||||
/** Should be implemented by child classes to perform specific Generators. */
|
||||
override def eval(input: Row): TraversableOnce[Row]
|
||||
|
||||
/**
|
||||
* Notifies that there are no more rows to process, clean up code, and additional
|
||||
* rows can be made here.
|
||||
*/
|
||||
def terminate(): TraversableOnce[Row] = Nil
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,6 +21,18 @@ import org.apache.spark.annotation.DeveloperApi
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
||||
/**
 * Defers invoking `func` until the iterator is first consumed, so that
 * `generator.terminate()` only runs after every row in the partition has
 * been processed.
 * TODO reusing the CompletionIterator?
 */
private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row])
  extends Iterator[Row] {

  // Materialized on first access, not at construction time.
  lazy val results = func().toIterator

  override def hasNext: Boolean = results.hasNext

  override def next(): Row = results.next()
}
|
||||
|
||||
/**
|
||||
* :: DeveloperApi ::
|
||||
* Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the
|
||||
|
@ -47,27 +59,33 @@ case class Generate(
|
|||
val boundGenerator = BindReferences.bindReference(generator, child.output)
|
||||
|
||||
protected override def doExecute(): RDD[Row] = {
  // boundGenerator.terminate() should be triggered after all of the rows in the partition
  if (join) {
    child.execute().mapPartitions { iter =>
      // Right-side row used when outer = true and the generator emits nothing.
      val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
      val joinedRow = new JoinedRow

      iter.flatMap { row =>
        // we should always set the left (child output)
        joinedRow.withLeft(row)
        val outputRows = boundGenerator.eval(row)
        if (outer && outputRows.isEmpty) {
          joinedRow.withRight(generatorNullRow) :: Nil
        } else {
          outputRows.map(or => joinedRow.withRight(or))
        }
      } ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
        // we leave the left side as the last element of its child output
        // keep it the same as Hive does
        joinedRow.withRight(row)
      }
    }
  } else {
    child.execute().mapPartitions { iter =>
      // Append the rows (if any) produced by terminate() after the last input row.
      iter.flatMap(row => boundGenerator.eval(row)) ++
        LazyIterator(() => boundGenerator.terminate())
    }
  }
}
|
||||
}
|
||||
|
||||
|
|
|
@ -483,7 +483,11 @@ private[hive] case class HiveGenericUdtf(
|
|||
extends Generator with HiveInspectors {
|
||||
|
||||
@transient
protected lazy val function: GenericUDTF = {
  val fun: GenericUDTF = funcWrapper.createFunction()
  // Attach the collector at creation time so rows emitted from both
  // process() and close() are captured.
  fun.setCollector(collector)
  fun
}
|
||||
|
||||
@transient
|
||||
protected lazy val inputInspectors = children.map(toInspector)
|
||||
|
@ -494,6 +498,9 @@ private[hive] case class HiveGenericUdtf(
|
|||
@transient
|
||||
protected lazy val udtInput = new Array[AnyRef](children.length)
|
||||
|
||||
@transient
|
||||
protected lazy val collector = new UDTFCollector
|
||||
|
||||
lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
|
||||
field => (inspectorToDataType(field.getFieldObjectInspector), true)
|
||||
}
|
||||
|
@ -502,8 +509,7 @@ private[hive] case class HiveGenericUdtf(
|
|||
outputInspector // Make sure initialized.
|
||||
|
||||
val inputProjection = new InterpretedProjection(children)
|
||||
val collector = new UDTFCollector
|
||||
function.setCollector(collector)
|
||||
|
||||
function.process(wrap(inputProjection(input), inputInspectors, udtInput))
|
||||
collector.collectRows()
|
||||
}
|
||||
|
@ -525,6 +531,12 @@ private[hive] case class HiveGenericUdtf(
|
|||
}
|
||||
}
|
||||
|
||||
override def terminate(): TraversableOnce[Row] = {
  // Touch outputInspector to make sure the UDTF has been initialized before
  // close() is invoked; close() may emit additional rows into the collector.
  outputInspector // Make sure initialized.
  function.close()
  collector.collectRows()
}
|
||||
|
||||
// Renders this expression as e.g. "HiveGenericUdtf#com.example.MyUDTF(a,b)".
override def toString: String = {
  s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}
|
||||
|
|
BIN
sql/hive/src/test/resources/TestUDTF.jar
Normal file
BIN
sql/hive/src/test/resources/TestUDTF.jar
Normal file
Binary file not shown.
|
@ -0,0 +1,2 @@
|
|||
97 500
|
||||
97 500
|
|
@ -0,0 +1,2 @@
|
|||
3
|
||||
3
|
|
@ -20,6 +20,9 @@ package org.apache.spark.sql.hive.execution
|
|||
import java.io.File
|
||||
import java.util.{Locale, TimeZone}
|
||||
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
|
||||
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, StructObjectInspector, ObjectInspector}
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import scala.util.Try
|
||||
|
@ -51,14 +54,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
|||
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
|
||||
// Add Locale setting
|
||||
Locale.setDefault(Locale.US)
|
||||
sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
|
||||
// The function source code can be found at:
|
||||
// https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
|
||||
sql(
|
||||
"""
|
||||
|CREATE TEMPORARY FUNCTION udtf_count2
|
||||
|AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
TestHive.cacheTables = false
|
||||
TimeZone.setDefault(originalTimeZone)
|
||||
Locale.setDefault(originalLocale)
|
||||
sql("DROP TEMPORARY FUNCTION udtf_count2")
|
||||
}
|
||||
|
||||
createQueryTest("Test UDTF.close in Lateral Views",
|
||||
"""
|
||||
|SELECT key, cc
|
||||
|FROM src LATERAL VIEW udtf_count2(value) dd AS cc
|
||||
""".stripMargin, false) // false mean we have to keep the temp function in registry
|
||||
|
||||
createQueryTest("Test UDTF.close in SELECT",
|
||||
"SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false)
|
||||
|
||||
test("SPARK-4908: concurrent hive native commands") {
|
||||
(1 to 100).par.map { _ =>
|
||||
sql("USE default")
|
||||
|
|
Loading…
Reference in a new issue