[SPARK-16906][SQL] Adds auxiliary info like input class and input schema in TypedAggregateExpression

## What changes were proposed in this pull request?

This PR adds auxiliary info, such as the input class and input schema, to `TypedAggregateExpression`. Instead of receiving only the input deserializer expression, `TypedColumn.withInputType` now takes the whole input `ExpressionEncoder`, so the typed aggregate can also record the input's runtime class and schema.
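For context, a minimal sketch of the typed-aggregation path this touches (the `Data` class and `SumOfI` aggregator are hypothetical, not part of this patch): an `Aggregator` is turned into a `TypedColumn` via `toColumn`, and `Dataset.select` then calls `withInputType` on it.

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input type and aggregator, for illustration only.
case class Data(i: Int, s: String)

object SumOfI extends Aggregator[Data, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Data): Long = b + a.i
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// ds: Dataset[Data]. select() passes ds's ExpressionEncoder to
// withInputType, which now also records Data's class and schema:
// val total: Dataset[Long] = ds.select(SumOfI.toColumn)
```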

## How was this patch tested?

Manual test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14501 from clockfly/typed_aggregation.
Commit 94a9d11ed1 (parent 06f5dc8415), authored by Sean Zhong on 2016-08-08 22:20:54 +08:00, committed by Wenchen Fan.
5 changed files with 14 additions and 7 deletions.

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

```diff
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputDeserializer: Expression,
+      inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+        ta.copy(
+          inputDeserializer = Some(unresolvedDeserializer),
+          inputClass = Some(inputEncoder.clsTag.runtimeClass),
+          inputSchema = Some(inputEncoder.schema))
     }
     new TypedColumn[T, U](newExpr, encoder)
   }
```
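A quick sketch of the three pieces of auxiliary info the encoder supplies (Spark 2.0-era `ExpressionEncoder` API; `Event` is a hypothetical input class):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Event(id: Long, value: Double)  // hypothetical input type

val enc = ExpressionEncoder[Event]()
enc.deserializer         // Expression that decodes internal rows into Event objects
enc.clsTag.runtimeClass  // classOf[Event], now stored as inputClass
enc.schema               // StructType of (id, value), now stored as inputSchema
```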

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

```diff
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
   @Experimental
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
       logicalPlan)
     if (encoder.flat) {
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
```
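Both the single- and multi-column `select` overloads now hand the full encoder over; a usage sketch reusing the hypothetical `Data` and `SumOfI` from above:

```scala
import org.apache.spark.sql.Dataset

// Assumes ds: Dataset[Data] and the illustrative SumOfI defined earlier.
def selectDemo(ds: Dataset[Data]): Unit = {
  val single: Dataset[Long] = ds.select(SumOfI.toColumn)  // select[U1]
  val pair: Dataset[(Long, Long)] =
    ds.select(SumOfI.toColumn, SumOfI.toColumn)            // goes through selectUntyped
}
```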

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

```diff
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+      columns.map(_.withInputType(vExprEnc, dataAttributes).named)
     val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
```
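On the grouped path it is the value-side encoder (`vExprEnc`) that is recorded, since the aggregator consumes the grouped values; a sketch with the same hypothetical types:

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Assumes ds: Dataset[Data] and the illustrative SumOfI defined earlier.
def groupedDemo(spark: SparkSession, ds: Dataset[Data]): Unit = {
  import spark.implicits._  // encoder for the String key
  // The aggregator consumes the grouped Data values, so aggUntyped passes
  // the value encoder (the V of KeyValueGroupedDataset[K, V]) to withInputType.
  val sums: Dataset[(String, Long)] = ds.groupByKey(_.s).agg(SumOfI.toColumn)
}
```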

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

```diff
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

```diff
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
+      None,
+      None,
       bufferSerializer,
       bufferDeserializer,
       outputEncoder.serializer,
@@ -62,6 +64,8 @@
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
```
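The two new fields start as `None` in the `apply` factory above and are filled in by `withInputType` during analysis. A hedged sketch (not from this patch) of how downstream code could read the auxiliary info back out, assuming it can see the `execution.aggregate` package:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression

// Sketch: inspect the recorded input class and schema after resolution.
def describeInput(expr: Expression): String = expr match {
  case ta: TypedAggregateExpression =>
    val cls    = ta.inputClass.map(_.getName).getOrElse("<unresolved>")
    val schema = ta.inputSchema.map(_.simpleString).getOrElse("<unresolved>")
    s"typed aggregate over $cls: $schema"
  case other =>
    s"not a typed aggregate: ${other.prettyName}"
}
```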