[SPARK-25768][SQL] fix constant argument expecting UDAFs
## What changes were proposed in this pull request? Without this PR some UDAFs like `GenericUDAFPercentileApprox` can throw an exception because expecting a constant parameter (object inspector) as a particular argument. The exception is thrown because `toPrettySQL` call in `ResolveAliases` analyzer rule transforms a `Literal` parameter to a `PrettyAttribute` which is then transformed to an `ObjectInspector` instead of a `ConstantObjectInspector`. The exception comes from `getEvaluator` method of `GenericUDAFPercentileApprox` that actually shouldn't be called during `toPrettySQL` transformation. The reason why it is called are the non lazy fields in `HiveUDAFFunction`. This PR makes all fields of `HiveUDAFFunction` lazy. ## How was this patch tested? added new UT Closes #22766 from peter-toth/SPARK-25768. Authored-by: Peter Toth <peter.toth@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
e8167768cf
commit
f38594fc56
|
@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction(
|
|||
resolver.getEvaluator(parameterInfo)
|
||||
}
|
||||
|
||||
// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
|
||||
@transient
|
||||
private lazy val partial1ModeEvaluator = newEvaluator()
|
||||
private case class HiveEvaluator(
|
||||
evaluator: GenericUDAFEvaluator,
|
||||
objectInspector: ObjectInspector)
|
||||
|
||||
// The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
|
||||
// Hive `ObjectInspector` used to inspect partial aggregation results.
|
||||
@transient
|
||||
private val partialResultInspector = partial1ModeEvaluator.init(
|
||||
GenericUDAFEvaluator.Mode.PARTIAL1,
|
||||
inputInspectors
|
||||
)
|
||||
private lazy val partial1HiveEvaluator = {
|
||||
val evaluator = newEvaluator()
|
||||
HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
|
||||
}
|
||||
|
||||
// The UDAF evaluator used to merge partial aggregation results.
|
||||
@transient
|
||||
private lazy val partial2ModeEvaluator = {
|
||||
val evaluator = newEvaluator()
|
||||
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
|
||||
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
|
||||
evaluator
|
||||
}
|
||||
|
||||
// Spark SQL data type of partial aggregation results
|
||||
@transient
|
||||
private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
|
||||
private lazy val partialResultDataType =
|
||||
inspectorToDataType(partial1HiveEvaluator.objectInspector)
|
||||
|
||||
// The UDAF evaluator used to compute the final result from a partial aggregation result objects.
|
||||
@transient
|
||||
private lazy val finalModeEvaluator = newEvaluator()
|
||||
|
||||
// Hive `ObjectInspector` used to inspect the final aggregation result object.
|
||||
@transient
|
||||
private val returnInspector = finalModeEvaluator.init(
|
||||
GenericUDAFEvaluator.Mode.FINAL,
|
||||
Array(partialResultInspector)
|
||||
)
|
||||
private lazy val finalHiveEvaluator = {
|
||||
val evaluator = newEvaluator()
|
||||
HiveEvaluator(
|
||||
evaluator,
|
||||
evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
|
||||
}
|
||||
|
||||
// Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
|
||||
@transient
|
||||
|
@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction(
|
|||
// Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
|
||||
// Spark SQL specific format.
|
||||
@transient
|
||||
private lazy val resultUnwrapper = unwrapperFor(returnInspector)
|
||||
private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector)
|
||||
|
||||
@transient
|
||||
private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
|
||||
|
@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction(
|
|||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override lazy val dataType: DataType = inspectorToDataType(returnInspector)
|
||||
override lazy val dataType: DataType = inspectorToDataType(finalHiveEvaluator.objectInspector)
|
||||
|
||||
override def prettyName: String = name
|
||||
|
||||
|
@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction(
|
|||
}
|
||||
|
||||
override def createAggregationBuffer(): AggregationBuffer =
|
||||
partial1ModeEvaluator.getNewAggregationBuffer
|
||||
partial1HiveEvaluator.evaluator.getNewAggregationBuffer
|
||||
|
||||
@transient
|
||||
private lazy val inputProjection = UnsafeProjection.create(children)
|
||||
|
||||
override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
|
||||
partial1ModeEvaluator.iterate(
|
||||
partial1HiveEvaluator.evaluator.iterate(
|
||||
buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
|
||||
buffer
|
||||
}
|
||||
|
@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction(
|
|||
// buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
|
||||
// this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
|
||||
// calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
|
||||
partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
|
||||
partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
|
||||
buffer
|
||||
}
|
||||
|
||||
override def eval(buffer: AggregationBuffer): Any = {
|
||||
resultUnwrapper(finalModeEvaluator.terminate(buffer))
|
||||
resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
|
||||
}
|
||||
|
||||
override def serialize(buffer: AggregationBuffer): Array[Byte] = {
|
||||
|
@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction(
|
|||
|
||||
// Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
|
||||
private class AggregationBufferSerDe {
|
||||
private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
|
||||
private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector)
|
||||
|
||||
private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
|
||||
private val partialResultWrapper =
|
||||
wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType)
|
||||
|
||||
private val projection = UnsafeProjection.create(Array(partialResultDataType))
|
||||
|
||||
|
@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction(
|
|||
// `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
|
||||
// that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
|
||||
// Then we can unwrap it to a Spark SQL value.
|
||||
mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
|
||||
mutableRow.update(0, partialResultUnwrapper(
|
||||
partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
|
||||
val unsafeRow = projection(mutableRow)
|
||||
val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
|
||||
unsafeRow.writeTo(bytes)
|
||||
|
|
|
@ -638,6 +638,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
|
|||
Row(3) :: Row(3) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-25768 constant argument expecting Hive UDF") {
|
||||
withTempView("inputTable") {
|
||||
spark.range(10).createOrReplaceTempView("inputTable")
|
||||
withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
|
||||
val numFunc = spark.catalog.listFunctions().count()
|
||||
sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
|
||||
s"${classOf[GenericUDAFPercentileApprox].getName}'")
|
||||
checkAnswer(
|
||||
sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"),
|
||||
Seq(Row(4.0)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class TestPair(x: Int, y: Int) extends Writable with Serializable {
|
||||
|
|
Loading…
Reference in a new issue