[SPARK-16906][SQL] Adds auxiliary info like input class and input schema in TypedAggregateExpression
## What changes were proposed in this pull request?

This PR adds auxiliary information, such as the input class and input schema, to `TypedAggregateExpression`.

## How was this patch tested?

Manual test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14501 from clockfly/typed_aggregation.
Commit `94a9d11ed1` (parent `06f5dc8415`)
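For context, here is a minimal sketch of a typed aggregation whose input encoder carries the class and schema that `withInputType` now records. The `Data` case class and `SumOfI` aggregator are hypothetical examples for illustration only, not part of this patch.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input type; its expression encoder exposes both
// clsTag.runtimeClass and schema, which this patch starts recording.
case class Data(i: Int)

// A trivial typed aggregator over Data, summing the `i` field.
object SumOfI extends Aggregator[Data, Long, Long] {
  def zero: Long = 0L
  def reduce(buffer: Long, a: Data): Long = buffer + a.i
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Calling ds.select(SumOfI.toColumn) on a Dataset[Data] goes through
// TypedColumn.withInputType (see the first hunk below); with this patch the
// resulting TypedAggregateExpression also keeps inputClass (classOf[Data])
// and inputSchema (the encoder schema of Data).
```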
```diff
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputDeserializer: Expression,
+      inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+        ta.copy(
+          inputDeserializer = Some(unresolvedDeserializer),
+          inputClass = Some(inputEncoder.clsTag.runtimeClass),
+          inputSchema = Some(inputEncoder.schema))
     }
     new TypedColumn[T, U](newExpr, encoder)
   }
```
```diff
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
   @Experimental
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
       logicalPlan)

     if (encoder.flat) {
```
```diff
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
```
```diff
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+      columns.map(_.withInputType(vExprEnc, dataAttributes).named)
     val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
```
```diff
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
```
```diff
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
+      None,
+      None,
       bufferSerializer,
       bufferDeserializer,
       outputEncoder.serializer,
```
```diff
@@ -62,6 +64,8 @@ object TypedAggregateExpression {
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
```
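For illustration only (not part of the patch), these are the auxiliary values that would be recorded for the hypothetical `Data` example above, assuming the standard product encoder for `case class Data(i: Int)`:

```scala
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// For case class Data(i: Int), the expression encoder's schema is a single
// non-nullable integer field, so the new fields would hold:
val expectedInputClass: Class[_] = classOf[Data]
val expectedInputSchema: StructType =
  StructType(Seq(StructField("i", IntegerType, nullable = false)))
```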