[SPARK-12024][SQL] More efficient multi-column counting.

In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach taken in that PR introduces a bit of overhead by first creating a row only to check if all of the columns are non-null.

This PR fixes that technical debt. Count now takes multiple columns as its input. In order to make this work I have also added support for multiple columns in the single distinct code path.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10015 from hvanhovell/SPARK-12024.
This commit is contained in:
Herman van Hovell 2015-11-29 14:13:11 -08:00 committed by Yin Huai
parent cc7a1bc937
commit 3d28081e53
6 changed files with 33 additions and 86 deletions

View file

@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Count(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
override def dataType: DataType = LongType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
private lazy val count = AttributeReference("count", LongType)()
@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
)
override lazy val updateExpressions = Seq(
/* count = */ If(IsNull(child), count, count + 1L)
/* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
)
override lazy val mergeExpressions = Seq(
@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
}
object Count {
def apply(children: Seq[Expression]): Count = {
// This is used to deal with COUNT DISTINCT. When we have multiple
// children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
// Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
// null in the arguments, we will not count that row. So, we use DropAnyNull at here
// to return a null when any field of the created STRUCT is null.
val child = if (children.size > 1) {
DropAnyNull(CreateStruct(children))
} else {
children.head
}
Count(child)
}
def apply(child: Expression): Count = Count(child :: Nil)
}

View file

@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
}
}
/** Operator that drops a row when it contains any nulls. */
case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
protected override def nullSafeEval(input: Any): InternalRow = {
val row = input.asInstanceOf[InternalRow]
if (row.anyNull) {
null
} else {
row
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval.anyNull()) {
${ev.isNull} = true;
} else {
${ev.value} = $eval;
}
"""
})
}
}

View file

@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
val newChildren = children.filter {
case Literal(null, _) => false
case _ => true
}
val newChildren = children.filter(nonNullLiteral)
if (newChildren.length == 0) {
Literal.create(null, e.dataType)
} else if (newChildren.length == 1) {

View file

@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
}
}
test("function dropAnyNull") {
val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
val a = create_row("a", "q")
val nullStr: String = null
checkEvaluation(drop, a, a)
checkEvaluation(drop, null, create_row("b", nullStr))
checkEvaluation(drop, null, create_row(nullStr, nullStr))
val row = 'r.struct(
StructField("a", StringType, false),
StructField("b", StringType, true)).at(0)
checkEvaluation(DropAnyNull(row), null, create_row(null))
}
}

View file

@ -146,20 +146,16 @@ object Utils {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
// functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
// DISTINCT aggregate function, all of those functions will have the same column expression.
// DISTINCT aggregate function, all of those functions will have the same column expressions.
// For example, it would be valid for functionsWithDistinct to be
// [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
// disallowed because those two distinct aggregates have different column expressions.
val distinctColumnExpression: Expression = {
val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
assert(allDistinctColumnExpressions.length == 1)
allDistinctColumnExpressions.head
}
val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
val namedDistinctColumnExpressions = distinctColumnExpressions.map {
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
val groupingAttributes = groupingExpressions.map(_.toAttribute)
// 1. Create an Aggregate Operator for partial aggregations.
@ -170,10 +166,11 @@ object Utils {
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
val partialAggregateGroupingExpressions =
groupingExpressions ++ namedDistinctColumnExpressions
val partialAggregateResult =
groupingAttributes ++
Seq(distinctColumnAttribute) ++
distinctColumnAttributes ++
partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
@ -208,28 +205,28 @@ object Utils {
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val partialMergeAggregateResult =
groupingAttributes ++
Seq(distinctColumnAttribute) ++
distinctColumnAttributes ++
partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
if (usesTungstenAggregate) {
TungstenAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
} else {
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
completeAggregateExpressions = Nil,
completeAggregateAttributes = Nil,
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = partialMergeAggregateResult,
child = partialAggregate)
}
@ -244,14 +241,16 @@ object Utils {
expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
}
val distinctColumnAttributeLookup =
distinctColumnExpressions.zip(distinctColumnAttributes).toMap
val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true) =>
val rewrittenAggregateFunction = aggregateFunction.transformDown {
case expr if expr == distinctColumnExpression => distinctColumnAttribute
}.asInstanceOf[AggregateFunction]
val rewrittenAggregateFunction = aggregateFunction
.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
// We rewrite the aggregate function to a non-distinct aggregation because
// its input will have distinct arguments.
// We just keep the isDistinct setting to true, so when users look at the query plan,
@ -270,7 +269,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
} else {
@ -281,7 +280,7 @@ object Utils {
nonCompleteAggregateAttributes = finalAggregateAttributes,
completeAggregateExpressions = completeAggregateExpressions,
completeAggregateAttributes = completeAggregateAttributes,
initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
resultExpressions = resultExpressions,
child = partialMergeAggregate)
}

View file

@ -152,8 +152,8 @@ class WindowSpec private[sql](
case Sum(child) => WindowExpression(
UnresolvedWindowFunction("sum", child :: Nil),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case Count(child) => WindowExpression(
UnresolvedWindowFunction("count", child :: Nil),
case Count(children) => WindowExpression(
UnresolvedWindowFunction("count", children),
WindowSpecDefinition(partitionSpec, orderSpec, frame))
case First(child, ignoreNulls) => WindowExpression(
// TODO this is a hack for Hive UDAF first_value