[SPARK-24935][SQL] fix Hive UDAF with two aggregation buffers
## What changes were proposed in this pull request?
A Hive UDAF knows the aggregation mode when creating its aggregation buffer, so it can create different buffers for different inputs: the original data or an aggregation buffer. See the example in the [sketches library](7f9e76e9e0/src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java#L107).
However, the Hive UDAF adapter in Spark always creates the buffer with the PARTIAL1 mode evaluator, which can only deal with one kind of input: the original data. This PR fixes that.
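
To make the contract concrete, here is a minimal sketch of such a mode-aware evaluator (the `ModeAwareEvaluator`, `RawInputBuffer`, and `PartialResultBuffer` names are hypothetical, not part of this PR); the `MockUDAFEvaluator2` test class added below follows the same pattern:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AggregationBuffer, Mode}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical buffer types: one sized for consuming original input rows,
// one for merging partial aggregation results.
class RawInputBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer
class PartialResultBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer

abstract class ModeAwareEvaluator extends GenericUDAFEvaluator {
  private var aggMode: Mode = _

  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
    super.init(mode, parameters)
    // Hive tells the evaluator its mode here, and the evaluator remembers it.
    aggMode = mode
    // Stand-in output inspector for this sketch; a real UDAF returns an
    // inspector matching its actual output for the given mode.
    PrimitiveObjectInspectorFactory.javaLongObjectInspector
  }

  override def getNewAggregationBuffer: AggregationBuffer = {
    // PARTIAL1 and COMPLETE consume original data; PARTIAL2 and FINAL consume
    // partial aggregation results, which may need a different buffer layout.
    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
      new RawInputBuffer
    } else {
      new PartialResultBuffer
    }
  }
}
```

Spark's adapter, however, obtained every buffer from the PARTIAL1-mode evaluator, so a UDAF like the one above would receive a `RawInputBuffer` even when merging partial results.
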
All credit goes to pgandhi999, who investigated the problem, studied the Hive UDAF behavior, and wrote the tests.
Closes https://github.com/apache/spark/pull/23778.
## How was this patch tested?
A new test, "customized Hive UDAF with two aggregation buffers", in `HiveUDAFSuite`.
Closes #24144 from cloud-fan/hive.
Lead-authored-by: pgandhi <pgandhi@verizonmedia.com>
Co-authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
Parent: a15f17ce27
Commit: a6c207c9c0

The change to the Hive UDAF adapter, `HiveUDAFFunction`:

```diff
@@ -352,21 +352,8 @@ private[hive] case class HiveUDAFFunction(
     HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
   }
 
-  // The UDAF evaluator used to merge partial aggregation results.
-  @transient
-  private lazy val partial2ModeEvaluator = {
-    val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
-    evaluator
-  }
-
-  // Spark SQL data type of partial aggregation results
-  @transient
-  private lazy val partialResultDataType =
-    inspectorToDataType(partial1HiveEvaluator.objectInspector)
-
-  // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  // Hive `ObjectInspector` used to inspect the final aggregation result object.
+  // The UDAF evaluator used to consume partial aggregation results and produce final results.
+  // Hive `ObjectInspector` used to inspect final results.
   @transient
   private lazy val finalHiveEvaluator = {
     val evaluator = newEvaluator()
@@ -375,6 +362,11 @@ private[hive] case class HiveUDAFFunction(
       evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
   }
 
+  // Spark SQL data type of partial aggregation results
+  @transient
+  private lazy val partialResultDataType =
+    inspectorToDataType(partial1HiveEvaluator.objectInspector)
+
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
   private lazy val inputWrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
@@ -401,25 +393,43 @@ private[hive] case class HiveUDAFFunction(
     s"$name($distinct${children.map(_.sql).mkString(", ")})"
   }
 
-  override def createAggregationBuffer(): AggregationBuffer =
-    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+  // The hive UDAF may create different buffers to handle different inputs: original data or
+  // aggregate buffer. However, the Spark UDAF framework does not expose this information when
+  // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
+  // on demand, so that we can know what input we are dealing with.
+  override def createAggregationBuffer(): AggregationBuffer = null
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+    // The input is original data, we create buffer with the partial1 evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     partial1HiveEvaluator.evaluator.iterate(
-      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
-    buffer
+      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    nonNullBuffer
   }
 
   override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+    // The input is aggregate buffer, we create buffer with the final evaluator.
+    val nonNullBuffer = if (buffer == null) {
+      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+    } else {
+      buffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    buffer
+    finalHiveEvaluator.evaluator.merge(
+      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
+    nonNullBuffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
@@ -450,11 +460,19 @@ private[hive] case class HiveUDAFFunction(
     private val mutableRow = new GenericInternalRow(1)
 
     def serialize(buffer: AggregationBuffer): Array[Byte] = {
+      // The buffer may be null if there is no input. It's unclear if the hive UDAF accepts null
+      // buffer, for safety we create an empty buffer here.
+      val nonNullBuffer = if (buffer == null) {
+        partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      } else {
+        buffer
+      }
+
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
       mutableRow.update(0, partialResultUnwrapper(
-        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
+        partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)
@@ -466,11 +484,11 @@ private[hive] case class HiveUDAFFunction(
       // returned by `GenericUDAFEvaluator.terminatePartial()` back to an `AggregationBuffer`. The
       // workaround here is creating an initial `AggregationBuffer` first and then merge the
       // deserialized object into the buffer.
-      val buffer = partial2ModeEvaluator.getNewAggregationBuffer
+      val buffer = finalHiveEvaluator.evaluator.getNewAggregationBuffer
       val unsafeRow = new UnsafeRow(1)
       unsafeRow.pointTo(bytes, bytes.length)
       val partialResult = unsafeRow.get(0, partialResultDataType)
-      partial2ModeEvaluator.merge(buffer, partialResultWrapper(partialResult))
+      finalHiveEvaluator.evaluator.merge(buffer, partialResultWrapper(partialResult))
       buffer
     }
   }
```

The new test and mock UDAF in `HiveUDAFSuite`:

```diff
@@ -28,6 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -40,6 +41,7 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     super.beforeAll()
     sql(s"CREATE TEMPORARY FUNCTION mock AS '${classOf[MockUDAF].getName}'")
     sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
+    sql(s"CREATE TEMPORARY FUNCTION mock2 AS '${classOf[MockUDAF2].getName}'")
 
     Seq(
       (0: Integer) -> "val_0",
@@ -92,6 +94,23 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }
 
+  test("customized Hive UDAF with two aggregation buffers") {
+    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+
+    val aggs = df.queryExecution.executedPlan.collect {
+      case agg: ObjectHashAggregateExec => agg
+    }
+
+    // There should be two aggregate operators, one for partial aggregation, and the other for
+    // global aggregation.
+    assert(aggs.length == 2)
+
+    checkAnswer(df, Seq(
+      Row(0, Row(1, 1)),
+      Row(1, Row(1, 1))
+    ))
+  }
+
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {
@@ -127,12 +146,22 @@ class MockUDAF extends AbstractGenericUDAFResolver {
   override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator
 }
 
+class MockUDAF2 extends AbstractGenericUDAFResolver {
+  override def getEvaluator(info: Array[TypeInfo]): GenericUDAFEvaluator = new MockUDAFEvaluator2
+}
+
 class MockUDAFBuffer(var nonNullCount: Long, var nullCount: Long)
   extends GenericUDAFEvaluator.AbstractAggregationBuffer {
 
   override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
 }
 
+class MockUDAFBuffer2(var nonNullCount: Long, var nullCount: Long)
+  extends GenericUDAFEvaluator.AbstractAggregationBuffer {
+
+  override def estimate(): Int = JavaDataModel.PRIMITIVES2 * 2
+}
+
 class MockUDAFEvaluator extends GenericUDAFEvaluator {
   private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
 
@@ -184,3 +213,80 @@ class MockUDAFEvaluator extends GenericUDAFEvaluator {
 
   override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
 }
+
+// Same as MockUDAFEvaluator but using two aggregation buffers, one for PARTIAL1 and the other
+// for PARTIAL2.
+class MockUDAFEvaluator2 extends GenericUDAFEvaluator {
+  private val nonNullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+
+  private val nullCountOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector
+  private var aggMode: Mode = null
+
+  private val bufferOI = {
+    val fieldNames = Seq("nonNullCount", "nullCount").asJava
+    val fieldOIs = Seq(nonNullCountOI: ObjectInspector, nullCountOI: ObjectInspector).asJava
+    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
+  }
+
+  private val nonNullCountField = bufferOI.getStructFieldRef("nonNullCount")
+
+  private val nullCountField = bufferOI.getStructFieldRef("nullCount")
+
+  override def getNewAggregationBuffer: AggregationBuffer = {
+    // These 2 modes consume original data.
+    if (aggMode == Mode.PARTIAL1 || aggMode == Mode.COMPLETE) {
+      new MockUDAFBuffer(0L, 0L)
+    } else {
+      new MockUDAFBuffer2(0L, 0L)
+    }
+  }
+
+  override def reset(agg: AggregationBuffer): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    buffer.nonNullCount = 0L
+    buffer.nullCount = 0L
+  }
+
+  override def init(mode: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
+    aggMode = mode
+    bufferOI
+  }
+
+  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer]
+    if (parameters.head eq null) {
+      buffer.nullCount += 1L
+    } else {
+      buffer.nonNullCount += 1L
+    }
+  }
+
+  override def merge(agg: AggregationBuffer, partial: Object): Unit = {
+    if (partial ne null) {
+      val nonNullCount = nonNullCountOI.get(bufferOI.getStructFieldData(partial, nonNullCountField))
+      val nullCount = nullCountOI.get(bufferOI.getStructFieldData(partial, nullCountField))
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      buffer.nonNullCount += nonNullCount
+      buffer.nullCount += nullCount
+    }
+  }
+
+  // As this method is called for both states, Partial1 and Partial2, the hack in the method
+  // to check for class of aggregation buffer was necessary.
+  override def terminatePartial(agg: AggregationBuffer): AnyRef = {
+    var result: AnyRef = null
+    if (agg.getClass.toString.contains("MockUDAFBuffer2")) {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    } else {
+      val buffer = agg.asInstanceOf[MockUDAFBuffer]
+      result = Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+    }
+    result
+  }
+
+  override def terminate(agg: AggregationBuffer): AnyRef = {
+    val buffer = agg.asInstanceOf[MockUDAFBuffer2]
+    Array[Object](buffer.nonNullCount: java.lang.Long, buffer.nullCount: java.lang.Long)
+  }
+}
```