[SPARK-12024][SQL] More efficient multi-column counting.
In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null. This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path. cc yhuai Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #10015 from hvanhovell/SPARK-12024.
This commit is contained in:
parent
cc7a1bc937
commit
3d28081e53
|
@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
case class Count(child: Expression) extends DeclarativeAggregate {
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
|
||||
|
||||
override def nullable: Boolean = false
|
||||
|
||||
|
@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
|
|||
override def dataType: DataType = LongType
|
||||
|
||||
// Expected input data type.
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
|
||||
|
||||
private lazy val count = AttributeReference("count", LongType)()
|
||||
|
||||
|
@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
|
|||
)
|
||||
|
||||
override lazy val updateExpressions = Seq(
|
||||
/* count = */ If(IsNull(child), count, count + 1L)
|
||||
/* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
|
||||
)
|
||||
|
||||
override lazy val mergeExpressions = Seq(
|
||||
|
@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
|
|||
}
|
||||
|
||||
object Count {
|
||||
def apply(children: Seq[Expression]): Count = {
|
||||
// This is used to deal with COUNT DISTINCT. When we have multiple
|
||||
// children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
|
||||
// Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
|
||||
// null in the arguments, we will not count that row. So, we use DropAnyNull at here
|
||||
// to return a null when any field of the created STRUCT is null.
|
||||
val child = if (children.size > 1) {
|
||||
DropAnyNull(CreateStruct(children))
|
||||
} else {
|
||||
children.head
|
||||
}
|
||||
Count(child)
|
||||
}
|
||||
def apply(child: Expression): Count = Count(child :: Nil)
|
||||
}
|
||||
|
|
|
@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
/** Operator that drops a row when it contains any nulls. */
|
||||
case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
|
||||
override def nullable: Boolean = true
|
||||
override def dataType: DataType = child.dataType
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
|
||||
|
||||
protected override def nullSafeEval(input: Any): InternalRow = {
|
||||
val row = input.asInstanceOf[InternalRow]
|
||||
if (row.anyNull) {
|
||||
null
|
||||
} else {
|
||||
row
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
nullSafeCodeGen(ctx, ev, eval => {
|
||||
s"""
|
||||
if ($eval.anyNull()) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.value} = $eval;
|
||||
}
|
||||
"""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
|
|||
* Null value propagation from bottom to top of the expression tree.
|
||||
*/
|
||||
object NullPropagation extends Rule[LogicalPlan] {
|
||||
def nonNullLiteral(e: Expression): Boolean = e match {
|
||||
case Literal(null, _) => false
|
||||
case _ => true
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
|
||||
case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
|
||||
Cast(Literal(0L), e.dataType)
|
||||
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
|
||||
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
|
||||
|
@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
|
|||
Literal.create(null, e.dataType)
|
||||
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
|
||||
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
|
||||
case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
|
||||
case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
|
||||
// This rule should be only triggered when isDistinct field is false.
|
||||
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
|
||||
|
||||
// For Coalesce, remove null literals.
|
||||
case e @ Coalesce(children) =>
|
||||
val newChildren = children.filter {
|
||||
case Literal(null, _) => false
|
||||
case _ => true
|
||||
}
|
||||
val newChildren = children.filter(nonNullLiteral)
|
||||
if (newChildren.length == 0) {
|
||||
Literal.create(null, e.dataType)
|
||||
} else if (newChildren.length == 1) {
|
||||
|
|
|
@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
|
||||
}
|
||||
}
|
||||
|
||||
test("function dropAnyNull") {
|
||||
val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
|
||||
val a = create_row("a", "q")
|
||||
val nullStr: String = null
|
||||
checkEvaluation(drop, a, a)
|
||||
checkEvaluation(drop, null, create_row("b", nullStr))
|
||||
checkEvaluation(drop, null, create_row(nullStr, nullStr))
|
||||
|
||||
val row = 'r.struct(
|
||||
StructField("a", StringType, false),
|
||||
StructField("b", StringType, true)).at(0)
|
||||
checkEvaluation(DropAnyNull(row), null, create_row(null))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -146,20 +146,16 @@ object Utils {
|
|||
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
|
||||
|
||||
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
|
||||
// DISTINCT aggregate function, all of those functions will have the same column expression.
|
||||
// DISTINCT aggregate function, all of those functions will have the same column expressions.
|
||||
// For example, it would be valid for functionsWithDistinct to be
|
||||
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
|
||||
// disallowed because those two distinct aggregates have different column expressions.
|
||||
val distinctColumnExpression: Expression = {
|
||||
val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
|
||||
assert(allDistinctColumnExpressions.length == 1)
|
||||
allDistinctColumnExpressions.head
|
||||
}
|
||||
val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
|
||||
val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
|
||||
val namedDistinctColumnExpressions = distinctColumnExpressions.map {
|
||||
case ne: NamedExpression => ne
|
||||
case other => Alias(other, other.toString)()
|
||||
}
|
||||
val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
|
||||
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
|
||||
val groupingAttributes = groupingExpressions.map(_.toAttribute)
|
||||
|
||||
// 1. Create an Aggregate Operator for partial aggregations.
|
||||
|
@ -170,10 +166,11 @@ object Utils {
|
|||
// We will group by the original grouping expression, plus an additional expression for the
|
||||
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
|
||||
// expressions will be [key, value].
|
||||
val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
|
||||
val partialAggregateGroupingExpressions =
|
||||
groupingExpressions ++ namedDistinctColumnExpressions
|
||||
val partialAggregateResult =
|
||||
groupingAttributes ++
|
||||
Seq(distinctColumnAttribute) ++
|
||||
distinctColumnAttributes ++
|
||||
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
|
||||
if (usesTungstenAggregate) {
|
||||
TungstenAggregate(
|
||||
|
@ -208,28 +205,28 @@ object Utils {
|
|||
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
|
||||
val partialMergeAggregateResult =
|
||||
groupingAttributes ++
|
||||
Seq(distinctColumnAttribute) ++
|
||||
distinctColumnAttributes ++
|
||||
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
|
||||
if (usesTungstenAggregate) {
|
||||
TungstenAggregate(
|
||||
requiredChildDistributionExpressions = Some(groupingAttributes),
|
||||
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
|
||||
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
|
||||
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
|
||||
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
|
||||
completeAggregateExpressions = Nil,
|
||||
completeAggregateAttributes = Nil,
|
||||
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
|
||||
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
|
||||
resultExpressions = partialMergeAggregateResult,
|
||||
child = partialAggregate)
|
||||
} else {
|
||||
SortBasedAggregate(
|
||||
requiredChildDistributionExpressions = Some(groupingAttributes),
|
||||
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
|
||||
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
|
||||
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
|
||||
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
|
||||
completeAggregateExpressions = Nil,
|
||||
completeAggregateAttributes = Nil,
|
||||
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
|
||||
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
|
||||
resultExpressions = partialMergeAggregateResult,
|
||||
child = partialAggregate)
|
||||
}
|
||||
|
@ -244,14 +241,16 @@ object Utils {
|
|||
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
|
||||
}
|
||||
|
||||
val distinctColumnAttributeLookup =
|
||||
distinctColumnExpressions.zip(distinctColumnAttributes).toMap
|
||||
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
|
||||
// Children of an AggregateFunction with DISTINCT keyword has already
|
||||
// been evaluated. At here, we need to replace original children
|
||||
// to AttributeReferences.
|
||||
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
|
||||
val rewrittenAggregateFunction = aggregateFunction.transformDown {
|
||||
case expr if expr == distinctColumnExpression => distinctColumnAttribute
|
||||
}.asInstanceOf[AggregateFunction]
|
||||
val rewrittenAggregateFunction = aggregateFunction
|
||||
.transformDown(distinctColumnAttributeLookup)
|
||||
.asInstanceOf[AggregateFunction]
|
||||
// We rewrite the aggregate function to a non-distinct aggregation because
|
||||
// its input will have distinct arguments.
|
||||
// We just keep the isDistinct setting to true, so when users look at the query plan,
|
||||
|
@ -270,7 +269,7 @@ object Utils {
|
|||
nonCompleteAggregateAttributes = finalAggregateAttributes,
|
||||
completeAggregateExpressions = completeAggregateExpressions,
|
||||
completeAggregateAttributes = completeAggregateAttributes,
|
||||
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
|
||||
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
|
||||
resultExpressions = resultExpressions,
|
||||
child = partialMergeAggregate)
|
||||
} else {
|
||||
|
@ -281,7 +280,7 @@ object Utils {
|
|||
nonCompleteAggregateAttributes = finalAggregateAttributes,
|
||||
completeAggregateExpressions = completeAggregateExpressions,
|
||||
completeAggregateAttributes = completeAggregateAttributes,
|
||||
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
|
||||
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
|
||||
resultExpressions = resultExpressions,
|
||||
child = partialMergeAggregate)
|
||||
}
|
||||
|
|
|
@ -152,8 +152,8 @@ class WindowSpec private[sql](
|
|||
case Sum(child) => WindowExpression(
|
||||
UnresolvedWindowFunction("sum", child :: Nil),
|
||||
WindowSpecDefinition(partitionSpec, orderSpec, frame))
|
||||
case Count(child) => WindowExpression(
|
||||
UnresolvedWindowFunction("count", child :: Nil),
|
||||
case Count(children) => WindowExpression(
|
||||
UnresolvedWindowFunction("count", children),
|
||||
WindowSpecDefinition(partitionSpec, orderSpec, frame))
|
||||
case First(child, ignoreNulls) => WindowExpression(
|
||||
// TODO this is a hack for Hive UDAF first_value
|
||||
|
|
Loading…
Reference in a new issue