[SPARK-8828] [SQL] Revert SPARK-5680
JIRA: https://issues.apache.org/jira/browse/SPARK-8828 Author: Yijie Shen <henry.yijieshen@gmail.com> Closes #7667 from yjshen/revert_combinesum_2 and squashes the following commits: c37ccb1 [Yijie Shen] add test case 8377214 [Yijie Shen] revert spark.sql.useAggregate2 to its default value e2305ac [Yijie Shen] fix bug - avg on decimal column 7cb0e95 [Yijie Shen] [wip] resolving bugs 1fadb5a [Yijie Shen] remove occurrence 17c6248 [Yijie Shen] revert SPARK-5680
This commit is contained in:
parent
3bc7055e26
commit
63a492b931
|
@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
|
|||
|
||||
// partialSum has already increased the precision by 10
|
||||
val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
|
||||
val castedCount = Sum(partialCount.toAttribute)
|
||||
val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
|
||||
SplitEvaluation(
|
||||
Cast(Divide(castedSum, castedCount), dataType),
|
||||
partialCount :: partialSum :: Nil)
|
||||
|
@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
|
|||
case DecimalType.Fixed(_, _) =>
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
Cast(CombineSum(partialSum.toAttribute), dataType),
|
||||
Cast(Sum(partialSum.toAttribute), dataType),
|
||||
partialSum :: Nil)
|
||||
|
||||
case _ =>
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
CombineSum(partialSum.toAttribute),
|
||||
Sum(partialSum.toAttribute),
|
||||
partialSum :: Nil)
|
||||
}
|
||||
}
|
||||
|
@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
|
|||
|
||||
private val sum = MutableLiteral(null, calcType)
|
||||
|
||||
private val addFunction =
|
||||
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
|
||||
private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
|
||||
|
||||
override def update(input: InternalRow): Unit = {
|
||||
sum.update(addFunction, input)
|
||||
|
@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sum should satisfy 3 cases:
|
||||
* 1) sum of all null values = zero
|
||||
* 2) sum for table column with no data = null
|
||||
* 3) sum of column with null and not null values = sum of not null values
|
||||
* Require separate CombineSum Expression and function as it has to distinguish "No data" case
|
||||
* versus "data equals null" case, while aggregating results and at each partial expression, i.e.,
|
||||
* Combining PartitionLevel InputData
|
||||
* <-- null
|
||||
* Zero <-- Zero <-- null
|
||||
*
|
||||
* <-- null <-- no data
|
||||
* null <-- null <-- no data
|
||||
*/
|
||||
case class CombineSum(child: Expression) extends AggregateExpression1 {
|
||||
def this() = this(null)
|
||||
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
override def nullable: Boolean = true
|
||||
override def dataType: DataType = child.dataType
|
||||
override def toString: String = s"CombineSum($child)"
|
||||
override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
|
||||
}
|
||||
|
||||
case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
|
||||
extends AggregateFunction1 {
|
||||
|
||||
def this() = this(null, null) // Required for serialization.
|
||||
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType.bounded(precision + 10, scale)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
private val zero = Cast(Literal(0), calcType)
|
||||
|
||||
private val sum = MutableLiteral(null, calcType)
|
||||
|
||||
private val addFunction =
|
||||
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
|
||||
|
||||
override def update(input: InternalRow): Unit = {
|
||||
val result = expr.eval(input)
|
||||
// partial sum result can be null only when no input rows are present
|
||||
if(result != null) {
|
||||
sum.update(addFunction, input)
|
||||
}
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(sum, dataType).eval(null)
|
||||
case _ => sum.eval(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
|
||||
|
||||
def this() = this(null)
|
||||
|
|
|
@ -108,7 +108,7 @@ case class GeneratedAggregate(
|
|||
Add(
|
||||
Coalesce(currentSum :: zero :: Nil),
|
||||
Cast(expr, calcType)
|
||||
) :: currentSum :: zero :: Nil)
|
||||
) :: currentSum :: Nil)
|
||||
val result =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
|
@ -118,45 +118,6 @@ case class GeneratedAggregate(
|
|||
|
||||
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case cs @ CombineSum(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(p, s) =>
|
||||
DecimalType.bounded(p + 10, s)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
|
||||
val initialValue = Literal.create(null, calcType)
|
||||
|
||||
// Coalesce avoids double calculation...
|
||||
// but really, common sub expression elimination would be better....
|
||||
val zero = Cast(Literal(0), calcType)
|
||||
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
|
||||
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
|
||||
val actualExpr = expr match {
|
||||
case UnscaledValue(e) => e
|
||||
case _ => expr
|
||||
}
|
||||
// partial sum result can be null only when no input rows are present
|
||||
val updateFunction = If(
|
||||
IsNotNull(actualExpr),
|
||||
Coalesce(
|
||||
Add(
|
||||
Coalesce(currentSum :: zero :: Nil),
|
||||
Cast(expr, calcType)) :: currentSum :: zero :: Nil),
|
||||
currentSum)
|
||||
|
||||
val result =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(currentSum, cs.dataType)
|
||||
case _ => currentSum
|
||||
}
|
||||
|
||||
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case m @ Max(expr) =>
|
||||
val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
|
||||
val initialValue = Literal.create(null, expr.dataType)
|
||||
|
|
|
@ -201,7 +201,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
}
|
||||
|
||||
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
|
||||
case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
|
||||
case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
|
||||
// The generated set implementation is pretty limited ATM.
|
||||
case CollectHashSet(exprs) if exprs.size == 1 &&
|
||||
Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
|
||||
|
|
|
@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
Seq(Row("1"), Row("2")))
|
||||
}
|
||||
|
||||
test("SPARK-8828 sum should return null if all input values are null") {
|
||||
withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") {
|
||||
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
sql("select sum(a), avg(a) from allNulls"),
|
||||
Seq(Row(null, null))
|
||||
)
|
||||
}
|
||||
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
|
||||
checkAnswer(
|
||||
sql("select sum(a), avg(a) from allNulls"),
|
||||
Seq(Row(null, null))
|
||||
)
|
||||
}
|
||||
}
|
||||
withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
|
||||
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
sql("select sum(a), avg(a) from allNulls"),
|
||||
Seq(Row(null, null))
|
||||
)
|
||||
}
|
||||
withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
|
||||
checkAnswer(
|
||||
sql("select sum(a), avg(a) from allNulls"),
|
||||
Seq(Row(null, null))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("aggregation with codegen") {
|
||||
val originalValue = sqlContext.conf.codegenEnabled
|
||||
sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
|
||||
|
|
|
@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
|||
"udaf_covar_pop",
|
||||
"udaf_covar_samp",
|
||||
"udaf_histogram_numeric",
|
||||
"udaf_number_format",
|
||||
"udf2",
|
||||
"udf5",
|
||||
"udf6",
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
0.0 NULL NULL NULL
|
Loading…
Reference in a new issue