From 63a492b931765b1edd66624421d503f1927825ec Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Mon, 27 Jul 2015 22:47:31 -0700
Subject: [PATCH] [SPARK-8828] [SQL] Revert SPARK-5680

JIRA: https://issues.apache.org/jira/browse/SPARK-8828

Author: Yijie Shen

Closes #7667 from yjshen/revert_combinesum_2 and squashes the following commits:

c37ccb1 [Yijie Shen] add test case
8377214 [Yijie Shen] revert spark.sql.useAggregate2 to its default value
e2305ac [Yijie Shen] fix bug - avg on decimal column
7cb0e95 [Yijie Shen] [wip] resolving bugs
1fadb5a [Yijie Shen] remove occurance
17c6248 [Yijie Shen] revert SPARK-5680
---
 .../sql/catalyst/expressions/aggregates.scala |  70 ++-----------------
 .../sql/execution/GeneratedAggregate.scala    |  41 +----------
 .../spark/sql/execution/SparkStrategies.scala |   2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  31 ++++++++
 .../execution/HiveCompatibilitySuite.scala    |   1 -
 ..._format-0-eff4ef3c207d14d5121368f294697964 |   0
 ..._format-1-4a03c4328565c60ca99689239f07fb16 |   1 -
 7 files changed, 37 insertions(+), 109 deletions(-)
 delete mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
 delete mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 42343d4d8d..5d4b349b15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
 
         // partialSum already increase the precision by 10
         val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
-        val castedCount = Sum(partialCount.toAttribute)
+        val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
         SplitEvaluation(
           Cast(Divide(castedSum, castedCount), dataType),
           partialCount :: partialSum :: Nil)
@@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
       case DecimalType.Fixed(_, _) =>
         val partialSum = Alias(Sum(child), "PartialSum")()
         SplitEvaluation(
-          Cast(CombineSum(partialSum.toAttribute), dataType),
+          Cast(Sum(partialSum.toAttribute), dataType),
           partialSum :: Nil)
 
       case _ =>
         val partialSum = Alias(Sum(child), "PartialSum")()
         SplitEvaluation(
-          CombineSum(partialSum.toAttribute),
+          Sum(partialSum.toAttribute),
           partialSum :: Nil)
     }
   }
@@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
 
   private val sum = MutableLiteral(null, calcType)
 
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+  private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
 
   override def update(input: InternalRow): Unit = {
     sum.update(addFunction, input)
@@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg
   }
 }
 
-/**
- * Sum should satisfy 3 cases:
- * 1) sum of all null values = zero
- * 2) sum for table column with no data = null
- * 3) sum of column with null and not null values = sum of not null values
- * Require separate CombineSum Expression and function as it has to distinguish "No data" case
- * versus "data equals null" case, while aggregating results and at each partial expression.i.e.,
- * Combining    PartitionLevel      InputData
- *     <-- null
- * Zero     <-- Zero         <-- null
- *
- *          <-- null         <-- no data
- * null     <-- null         <-- no data
- */
-case class CombineSum(child: Expression) extends AggregateExpression1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = child :: Nil
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def toString: String = s"CombineSum($child)"
-  override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
-}
-
-case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(precision, scale) =>
-        DecimalType.bounded(precision + 10, scale)
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private val sum = MutableLiteral(null, calcType)
-
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
-  override def update(input: InternalRow): Unit = {
-    val result = expr.eval(input)
-    // partial sum result can be null only when no input rows present
-    if(result != null) {
-      sum.update(addFunction, input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        Cast(sum, dataType).eval(null)
-      case _ => sum.eval(null)
-    }
-  }
-}
-
 case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
 
   def this() = this(null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 5ad4691a5c..1cd1420480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -108,7 +108,7 @@ case class GeneratedAggregate(
             Add(
               Coalesce(currentSum :: zero :: Nil),
               Cast(expr, calcType)
-            ) :: currentSum :: zero :: Nil)
+            ) :: currentSum :: Nil)
         val result =
           expr.dataType match {
             case DecimalType.Fixed(_, _) =>
@@ -118,45 +118,6 @@ case class GeneratedAggregate(
 
         AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
 
-      case cs @ CombineSum(expr) =>
-        val calcType =
-          expr.dataType match {
-            case DecimalType.Fixed(p, s) =>
-              DecimalType.bounded(p + 10, s)
-            case _ =>
-              expr.dataType
-          }
-
-        val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
-        val initialValue = Literal.create(null, calcType)
-
-        // Coalesce avoids double calculation...
-        // but really, common sub expression elimination would be better....
-        val zero = Cast(Literal(0), calcType)
-        // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
-        // UnscaledValue will be null if and only if x is null; helps with Average on decimals
-        val actualExpr = expr match {
-          case UnscaledValue(e) => e
-          case _ => expr
-        }
-        // partial sum result can be null only when no input rows present
-        val updateFunction = If(
-          IsNotNull(actualExpr),
-          Coalesce(
-            Add(
-              Coalesce(currentSum :: zero :: Nil),
-              Cast(expr, calcType)) :: currentSum :: zero :: Nil),
-          currentSum)
-
-        val result =
-          expr.dataType match {
-            case DecimalType.Fixed(_, _) =>
-              Cast(currentSum, cs.dataType)
-            case _ => currentSum
-          }
-
-        AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
-
       case m @ Max(expr) =>
         val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
         val initialValue = Literal.create(null, expr.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 306bbfec62..d88a02298c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -201,7 +201,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
-      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
+      case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
      // The generated set implementation is pretty limited ATM.
      case CollectHashSet(exprs) if exprs.size == 1 &&
           Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 358e319476..42724ed766 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Seq(Row("1"), Row("2")))
   }
 
+  test("SPARK-8828 sum should return null if all input values are null") {
+    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") {
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+    }
+    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+      withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+        checkAnswer(
+          sql("select sum(a), avg(a) from allNulls"),
+          Seq(Row(null, null))
+        )
+      }
+    }
+  }
+
   test("aggregation with codegen") {
     val originalValue = sqlContext.conf.codegenEnabled
     sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index b12b3838e6..ec959cb219 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udaf_covar_pop",
     "udaf_covar_samp",
     "udaf_histogram_numeric",
-    "udaf_number_format",
     "udf2",
     "udf5",
     "udf6",
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
deleted file mode 100644
index c6f275a0db..0000000000
--- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16
+++ /dev/null
@@ -1 +0,0 @@
-0.0 NULL NULL NULL
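Note for readers of this revert: the behaviour it restores is the one exercised by the new SQLQuerySuite test above, namely that sum(a) and avg(a) over a column containing only nulls must return NULL rather than 0, with or without codegen. The sketch below reproduces that setup against a Spark 1.4/1.5-era SQLContext outside the test suite; it is illustrative only. The "allNulls" table name and the expected NULL results come from the test in this patch, while the NullInts case class, the object name, and the local[2] master are assumed stand-ins mirroring Spark's own test data.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    // java.lang.Integer (not Int) so the column is nullable and can hold null values.
    case class NullInts(a: java.lang.Integer)

    object SumOfNullsDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("SumOfNullsDemo").setMaster("local[2]"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // A table whose only column is entirely null, mirroring the "allNulls"
        // temp table queried by the regression test added in this patch.
        sc.parallelize(Seq(NullInts(null), NullInts(null), NullInts(null), NullInts(null)))
          .toDF()
          .registerTempTable("allNulls")

        // With the revert applied, SUM and AVG over an all-null column both
        // evaluate to NULL (not 0), regardless of the codegen setting.
        sqlContext.sql("select sum(a), avg(a) from allNulls").show()
        // Expected: a single row with NULL in both columns.

        sc.stop()
      }
    }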