[SPARK-24935][SQL][FOLLOWUP] support INIT -> UPDATE -> MERGE -> FINISH in Hive UDAF adapter
## What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/24144. #24144 missed one case: when a hash aggregate falls back to a sort aggregate, the life cycle of the UDAF is INIT -> UPDATE -> MERGE -> FINISH.

However, not all Hive UDAFs can support this. A Hive UDAF knows the aggregation mode when creating its aggregation buffer, so it can create different buffers for different inputs: the original data or the aggregation buffer. See, for example, `DataToSketchUDAF.java` in the sketches library (`src/main/java/com/yahoo/sketches/hive/cpc/DataToSketchUDAF.java`, line 107, at commit `7f9e76e9e0`). The buffer created for UPDATE may not support MERGE.
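To make the problem concrete, here is a minimal, hypothetical mode-sensitive evaluator in the same style (the class, buffer names, and the simplified value extraction are illustrative, not from the sketches library or this PR). Because Hive passes the aggregation mode to `init()`, `getNewAggregationBuffer()` can return a different buffer class per mode, so a buffer created for UPDATE simply cannot be handed to `merge()`:

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.{AbstractAggregationBuffer, AggregationBuffer, Mode}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Hypothetical evaluator whose buffer layout depends on the mode given to init().
class ModeSensitiveSumEvaluator extends GenericUDAFEvaluator {
  class UpdateBuffer extends AbstractAggregationBuffer { var sum = 0L }   // sees raw rows
  class MergeBuffer extends AbstractAggregationBuffer { var total = 0L }  // sees partial results

  private var mode: Mode = _

  override def init(m: Mode, parameters: Array[ObjectInspector]): ObjectInspector = {
    super.init(m, parameters)
    mode = m
    PrimitiveObjectInspectorFactory.javaLongObjectInspector
  }

  // Different buffer implementations for the UPDATE and MERGE paths.
  override def getNewAggregationBuffer(): AggregationBuffer = mode match {
    case Mode.PARTIAL1 | Mode.COMPLETE => new UpdateBuffer
    case Mode.PARTIAL2 | Mode.FINAL => new MergeBuffer
  }

  override def reset(agg: AggregationBuffer): Unit = agg match {
    case b: UpdateBuffer => b.sum = 0L
    case b: MergeBuffer => b.total = 0L
  }

  // Simplified: a real UDAF would extract values via the input ObjectInspector.
  override def iterate(agg: AggregationBuffer, parameters: Array[AnyRef]): Unit =
    if (parameters(0) != null) {
      // Fails with a ClassCastException if it is handed a MergeBuffer.
      agg.asInstanceOf[UpdateBuffer].sum += parameters(0).toString.toLong
    }

  override def terminatePartial(agg: AggregationBuffer): AnyRef = agg match {
    case b: UpdateBuffer => java.lang.Long.valueOf(b.sum)
    case b: MergeBuffer => java.lang.Long.valueOf(b.total)
  }

  override def merge(agg: AggregationBuffer, partial: AnyRef): Unit =
    // Fails if the buffer was created in UPDATE mode: exactly the bug scenario.
    agg.asInstanceOf[MergeBuffer].total += partial.toString.toLong

  override def terminate(agg: AggregationBuffer): AnyRef = terminatePartial(agg)
}
```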
This PR updates the Hive UDAF adapter in Spark to support INIT -> UPDATE -> MERGE -> FINISH, by turning it into INIT -> UPDATE -> FINISH plus INIT -> MERGE -> FINISH.
## How was this patch tested?
A new test case.
Closes #24459 from cloud-fan/hive-udaf.
Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 618d6bff71
commit 7432e7ded4
```diff
@@ -304,6 +304,13 @@ private[hive] case class HiveGenericUDTF(
  *  - `wrap()`/`wrapperFor()`: from 3 to 1
  *  - `unwrap()`/`unwrapperFor()`: from 1 to 3
  *  - `GenericUDAFEvaluator.terminatePartial()`: from 2 to 3
+ *
+ * Note that, Hive UDAF is initialized with aggregate mode, and some specific Hive UDAFs can't
+ * mix UPDATE and MERGE actions during its life cycle. However, Spark may do UPDATE on a UDAF and
+ * then do MERGE, in case of hash aggregate falling back to sort aggregate. To work around this
+ * issue, we track the ability to do MERGE in the Hive UDAF aggregate buffer. If Spark does
+ * UPDATE then MERGE, we can detect it and re-create the aggregate buffer with a different
+ * aggregate mode.
  */
 private[hive] case class HiveUDAFFunction(
     name: String,
```
```diff
@@ -312,7 +319,7 @@ private[hive] case class HiveUDAFFunction(
     isUDAFBridgeRequired: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+  extends TypedImperativeAggregate[HiveUDAFBuffer]
   with HiveInspectors
   with UserDefinedExpression {
```
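For context, `TypedImperativeAggregate[T]` is Spark's base class for aggregate functions with arbitrarily typed buffers, so changing the type parameter from `AggregationBuffer` to `HiveUDAFBuffer` ripples through every override in the next hunk. A condensed sketch of the methods this diff touches (the real trait has more members, e.g. the buffer offsets above):

```scala
import org.apache.spark.sql.catalyst.InternalRow

// Condensed sketch of the TypedImperativeAggregate[T] contract implied by the
// overrides in this diff; not Spark's full trait.
abstract class TypedImperativeAggregateSketch[T] {
  def createAggregationBuffer(): T              // INIT
  def update(buffer: T, input: InternalRow): T  // UPDATE, consumes raw rows
  def merge(buffer: T, input: T): T             // MERGE, consumes other buffers
  def eval(buffer: T): Any                      // FINISH
  def serialize(buffer: T): Array[Byte]         // buffer -> shuffle bytes
  def deserialize(bytes: Array[Byte]): T        // shuffle bytes -> buffer
}
```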
```diff
@@ -410,55 +417,70 @@ private[hive] case class HiveUDAFFunction(
   // aggregate buffer. However, the Spark UDAF framework does not expose this information when
   // creating the buffer. Here we return null, and create the buffer in `update` and `merge`
   // on demand, so that we can know what input we are dealing with.
-  override def createAggregationBuffer(): AggregationBuffer = null
+  override def createAggregationBuffer(): HiveUDAFBuffer = null

   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)

-  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
+  override def update(buffer: HiveUDAFBuffer, input: InternalRow): HiveUDAFBuffer = {
     // The input is original data, we create buffer with the partial1 evaluator.
     val nonNullBuffer = if (buffer == null) {
-      partial1HiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(partial1HiveEvaluator.evaluator.getNewAggregationBuffer, false)
     } else {
       buffer
     }

+    assert(!nonNullBuffer.canDoMerge, "can not call `merge` then `update` on a Hive UDAF.")
+
     partial1HiveEvaluator.evaluator.iterate(
-      nonNullBuffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+      nonNullBuffer.buf, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     nonNullBuffer
   }

-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
+  override def merge(buffer: HiveUDAFBuffer, input: HiveUDAFBuffer): HiveUDAFBuffer = {
     // The input is aggregate buffer, we create buffer with the final evaluator.
     val nonNullBuffer = if (buffer == null) {
-      finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      HiveUDAFBuffer(finalHiveEvaluator.evaluator.getNewAggregationBuffer, true)
     } else {
       buffer
     }

+    // It's possible that we've called `update` of this Hive UDAF, and some specific Hive UDAF
+    // implementation can't mix the `update` and `merge` calls during its life cycle. To work
+    // around it, here we create a fresh buffer with final evaluator, and merge the existing buffer
+    // to it, and replace the existing buffer with it.
+    val mergeableBuf = if (!nonNullBuffer.canDoMerge) {
+      val newBuf = finalHiveEvaluator.evaluator.getNewAggregationBuffer
+      finalHiveEvaluator.evaluator.merge(
+        newBuf, partial1HiveEvaluator.evaluator.terminatePartial(nonNullBuffer.buf))
+      HiveUDAFBuffer(newBuf, true)
+    } else {
+      nonNullBuffer
+    }
+
     // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
     finalHiveEvaluator.evaluator.merge(
-      nonNullBuffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
-    nonNullBuffer
+      mergeableBuf.buf, partial1HiveEvaluator.evaluator.terminatePartial(input.buf))
+    mergeableBuf
   }

-  override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
+  override def eval(buffer: HiveUDAFBuffer): Any = {
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer.buf))
   }

-  override def serialize(buffer: AggregationBuffer): Array[Byte] = {
+  override def serialize(buffer: HiveUDAFBuffer): Array[Byte] = {
     // Serializes an `AggregationBuffer` that holds partial aggregation results so that we can
     // shuffle it for global aggregation later.
-    aggBufferSerDe.serialize(buffer)
+    aggBufferSerDe.serialize(buffer.buf)
   }

-  override def deserialize(bytes: Array[Byte]): AggregationBuffer = {
+  override def deserialize(bytes: Array[Byte]): HiveUDAFBuffer = {
     // Deserializes an `AggregationBuffer` from the shuffled partial aggregation phase to prepare
     // for global aggregation by merging multiple partial aggregation results within a single group.
-    aggBufferSerDe.deserialize(bytes)
+    HiveUDAFBuffer(aggBufferSerDe.deserialize(bytes), false)
   }

   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
```
```diff
@@ -506,3 +528,5 @@ private[hive] case class HiveUDAFFunction(
     }
   }
 }
+
+case class HiveUDAFBuffer(buf: AggregationBuffer, canDoMerge: Boolean)
```
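The new `HiveUDAFBuffer` wrapper makes each buffer's mode explicit via `canDoMerge`. The conversion that `merge` performs when it receives an UPDATE-mode buffer can be distilled into a standalone helper, a sketch under the assumption that `partial1Eval` and `finalEval` are the two mode-specific evaluators the adapter already maintains (the object and method names are hypothetical):

```scala
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer

object BufferConversionSketch {
  // The UPDATE-mode buffer is "finished" with terminatePartial(), which yields
  // the same partial-result format that is shuffled between stages, and the
  // result is folded into a fresh FINAL-mode buffer. The UDAF therefore never
  // sees UPDATE and MERGE on the same buffer.
  def toMergeable(
      updateModeBuf: AggregationBuffer,
      partial1Eval: GenericUDAFEvaluator,
      finalEval: GenericUDAFEvaluator): AggregationBuffer = {
    val freshBuf = finalEval.getNewAggregationBuffer  // re-INIT, merge-capable
    finalEval.merge(freshBuf, partial1Eval.terminatePartial(updateModeBuf))
    freshBuf
  }
}
```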
```diff
@@ -28,10 +28,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
 import test.org.apache.spark.sql.MyDoubleAvg

-import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils

 class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
```
```diff
@@ -94,21 +94,33 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     ))
   }

-  test("customized Hive UDAF with two aggregation buffers") {
-    val df = sql("SELECT key % 2, mock2(value) FROM t GROUP BY key % 2")
+  test("SPARK-24935: customized Hive UDAF with two aggregation buffers") {
+    withTempView("v") {
+      spark.range(100).createTempView("v")
+      val df = sql("SELECT id % 2, mock2(id) FROM v GROUP BY id % 2")

-    val aggs = df.queryExecution.executedPlan.collect {
-      case agg: ObjectHashAggregateExec => agg
-    }
+      val aggs = df.queryExecution.executedPlan.collect {
+        case agg: ObjectHashAggregateExec => agg
+      }

-    // There should be two aggregate operators, one for partial aggregation, and the other for
-    // global aggregation.
-    assert(aggs.length == 2)
+      // There should be two aggregate operators, one for partial aggregation, and the other for
+      // global aggregation.
+      assert(aggs.length == 2)

-    checkAnswer(df, Seq(
-      Row(0, Row(1, 1)),
-      Row(1, Row(1, 1))
-    ))
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+        checkAnswer(df, Seq(
+          Row(0, Row(50, 0)),
+          Row(1, Row(50, 0))
+        ))
+      }
+    }
   }

   test("call JAVA UDAF") {
```