[SPARK-8828] [SQL] Revert SPARK-5680

JIRA: https://issues.apache.org/jira/browse/SPARK-8828

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #7667 from yjshen/revert_combinesum_2 and squashes the following commits:

c37ccb1 [Yijie Shen] add test case
8377214 [Yijie Shen] revert spark.sql.useAggregate2 to its default value
e2305ac [Yijie Shen] fix bug - avg on decimal column
7cb0e95 [Yijie Shen] [wip] resolving bugs
1fadb5a [Yijie Shen] remove occurance
17c6248 [Yijie Shen] revert SPARK-5680
This commit is contained in:
Yijie Shen 2015-07-27 22:47:31 -07:00 committed by Michael Armbrust
parent 3bc7055e26
commit 63a492b931
7 changed files with 37 additions and 109 deletions

View file

@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
// partialSum already increase the precision by 10
val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
val castedCount = Sum(partialCount.toAttribute)
val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
SplitEvaluation(
Cast(Divide(castedSum, castedCount), dataType),
partialCount :: partialSum :: Nil)
@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
Cast(CombineSum(partialSum.toAttribute), dataType),
Cast(Sum(partialSum.toAttribute), dataType),
partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
CombineSum(partialSum.toAttribute),
Sum(partialSum.toAttribute),
partialSum :: Nil)
}
}
@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
private val sum = MutableLiteral(null, calcType)
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
override def update(input: InternalRow): Unit = {
sum.update(addFunction, input)
@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
}
}
/**
 * Aggregate expression applied to the partial results of a sum (see the split evaluation of
 * `Sum`, which wraps its `PartialSum` attribute in a `CombineSum`).
 *
 * Sum should satisfy 3 cases:
 *  1) sum of all null values = zero
 *  2) sum for table column with no data = null
 *  3) sum of column with null and not null values = sum of not null values
 *
 * A separate CombineSum expression and function are required because they have to distinguish
 * the "no data" case from the "data equals null" case, both while aggregating results and at
 * each partial expression, i.e.:
 *
 * {{{
 * Combining    PartitionLevel   InputData
 *                           <-- null
 * Zero     <-- Zero         <-- null
 *
 *          <-- null         <-- no data
 * null     <-- null         <-- no data
 * }}}
 *
 * @param child expression producing the partial sum to be combined
 */
case class CombineSum(child: Expression) extends AggregateExpression1 {

  def this() = this(null)

  // The result carries the child's type and may be null (the "no data" case above).
  override def dataType: DataType = child.dataType
  override def nullable: Boolean = true

  override def children: Seq[Expression] = Seq(child)

  override def toString: String = s"CombineSum($child)"

  /** Creates the stateful function that accumulates the combined sum for this expression. */
  override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
/**
 * Stateful function backing `CombineSum`: folds non-null partial sums into a running total.
 *
 * The running total starts as null and is left untouched by null partial results, so it is
 * still null at evaluation time if no non-null partial sum was ever seen.
 *
 * @param expr expression yielding the partial sum for each input row
 * @param base the aggregate expression this function was instantiated from
 */
case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
  extends AggregateFunction1 {

  def this() = this(null, null) // Required for serialization.

  // Fixed-precision decimals are accumulated with 10 extra digits of precision;
  // every other type is accumulated in its own type.
  private val calcType = expr.dataType match {
    case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale)
    case _ => expr.dataType
  }

  private val zero = Cast(Literal(0), calcType)

  // Running total; null until the first non-null partial sum has been added.
  private val sum = MutableLiteral(null, calcType)

  // (sum or zero) + expr, falling back to the current sum and then to zero if that is null.
  private val addFunction = {
    val runningOrZero = Coalesce(Seq(sum, zero))
    val incremented = Add(runningOrZero, Cast(expr, calcType))
    Coalesce(Seq(incremented, sum, zero))
  }

  override def update(input: InternalRow): Unit = {
    // partial sum result can be null only when no input rows present
    val partial = expr.eval(input)
    if (partial != null) {
      sum.update(addFunction, input)
    }
  }

  override def eval(input: InternalRow): Any = {
    // Decimal totals were accumulated in the widened calcType; cast them back to `dataType`.
    val resultExpr: Expression = expr.dataType match {
      case DecimalType.Fixed(_, _) => Cast(sum, dataType)
      case _ => sum
    }
    resultExpr.eval(null)
  }
}
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
def this() = this(null)

View file

@ -108,7 +108,7 @@ case class GeneratedAggregate(
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)
) :: currentSum :: zero :: Nil)
) :: currentSum :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@ -118,45 +118,6 @@ case class GeneratedAggregate(
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case cs @ CombineSum(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 10, s)
case _ =>
expr.dataType
}
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
val initialValue = Literal.create(null, calcType)
// Coalesce avoids double calculation...
// but really, common sub expression elimination would be better....
val zero = Cast(Literal(0), calcType)
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
val actualExpr = expr match {
case UnscaledValue(e) => e
case _ => expr
}
// partial sum result can be null only when no input rows present
val updateFunction = If(
IsNotNull(actualExpr),
Coalesce(
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)) :: currentSum :: zero :: Nil),
currentSum)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(currentSum, cs.dataType)
case _ => currentSum
}
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case m @ Max(expr) =>
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
val initialValue = Literal.create(null, expr.dataType)

View file

@ -201,7 +201,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => true

View file

@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Seq(Row("1"), Row("2")))
}
test("SPARK-8828 sum should return null if all input values are null") {
  // Exercise every combination of the aggregate2 and codegen switches, in the order
  // (true, true), (true, false), (false, true), (false, false).
  for {
    useAggregate2 <- Seq("true", "false")
    codegenEnabled <- Seq("true", "false")
  } {
    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> useAggregate2) {
      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> codegenEnabled) {
        // Both sum and avg over the allNulls table are expected to be null.
        checkAnswer(
          sql("select sum(a), avg(a) from allNulls"),
          Seq(Row(null, null))
        )
      }
    }
  }
}
test("aggregation with codegen") {
val originalValue = sqlContext.conf.codegenEnabled
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)

View file

@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udaf_covar_pop",
"udaf_covar_samp",
"udaf_histogram_numeric",
"udaf_number_format",
"udf2",
"udf5",
"udf6",