[SPARK-4318][SQL] Fix empty sum distinct.

Executing sum distinct for empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits:

8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318
66fdb0a [Takuya UESHIN] Re-refine aggregate functions.
6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate.
d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate.
1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions.
917e533 [Takuya UESHIN] Use aggregate instead of groupBy().
1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation.
a5a57d2 [Takuya UESHIN] Fix empty Average.
22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct.
65b7dd2 [Takuya UESHIN] Fix empty sum distinct.
This commit is contained in:
Takuya UESHIN 2014-11-20 15:41:24 -08:00 committed by Michael Armbrust
parent 98e9419784
commit 2c2e7a44db
4 changed files with 194 additions and 51 deletions

View file

@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
}
override def newInstance() = new CountFunction(child, this)
@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
child.dataType match {
case DecimalType.Fixed(_, _) =>
// Turn the results to unlimited decimals for the division, before going back to fixed
// Turn the child to unlimited decimals for calculation, before going back to fixed
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
SplitEvaluation(
@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
SplitEvaluation(
@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
Sum(partialSum.toAttribute),
partialSum :: Nil)
child.dataType match {
case DecimalType.Fixed(_, _) =>
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
SplitEvaluation(
Cast(Sum(partialSum.toAttribute), dataType),
partialSum :: Nil)
case _ =>
val partialSum = Alias(Sum(child), "PartialSum")()
SplitEvaluation(
Sum(partialSum.toAttribute),
partialSum :: Nil)
}
}
override def newInstance() = new SumFunction(child, this)
@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
override def nullable = false
override def nullable = true
override def dataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
private val zero = Cast(Literal(0), expr.dataType)
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private var count: Long = _
private val sum = MutableLiteral(zero.eval(null), expr.dataType)
private val sumAsDouble = Cast(sum, DoubleType)
private val sum = MutableLiteral(zero.eval(null), calcType)
private def addFunction(value: Any) = Add(sum, Literal(value))
private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType))
override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
override def eval(input: Row): Any = {
if (count == 0L) {
null
} else {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(Divide(
Cast(sum, DecimalType.Unlimited),
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
case _ =>
Divide(
Cast(sum, dataType),
Cast(Literal(count), dataType)).eval(null)
}
}
}
override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction(
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val zero = Cast(Literal(0), expr.dataType)
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val sum = MutableLiteral(zero.eval(null), expr.dataType)
private val zero = Cast(Literal(0), calcType)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
private val sum = MutableLiteral(null, calcType)
private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
override def update(input: Row): Unit = {
sum.update(addFunction, input)
}
override def eval(input: Row): Any = sum.eval(null)
override def eval(input: Row): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(sum, dataType).eval(null)
case _ => sum.eval(null)
}
}
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
override def eval(input: Row): Any =
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
override def eval(input: Row): Any = {
if (seen.size == 0) {
null
} else {
Cast(Literal(
seen.reduceLeft(
dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
dataType).eval(null)
}
}
}
case class CountDistinctFunction(

View file

@ -83,29 +83,45 @@ case class GeneratedAggregate(
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case Sum(expr) =>
val resultType = expr.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 10, scale)
case _ =>
expr.dataType
}
case s @ Sum(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
val initialValue = Cast(Literal(0L), resultType)
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
val initialValue = Literal(null, calcType)
// Coalasce avoids double calculation...
// but really, common sub expression elimination would be better....
val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
val result = currentSum
val zero = Cast(Literal(0), calcType)
val updateFunction = Coalesce(
Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(currentSum, s.dataType)
case _ => currentSum
}
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case a @ Average(expr) =>
val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
val currentSum = AttributeReference("currentSum", calcType, nullable = false)()
val initialCount = Literal(0L)
val initialSum = Cast(Literal(0L), expr.dataType)
val initialSum = Cast(Literal(0L), calcType)
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
@ -115,17 +131,21 @@ case class GeneratedAggregate(
}
val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil)
val resultType = expr.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited =>
DecimalType.Unlimited
case _ =>
DoubleType
}
val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
If(EqualTo(currentCount, Literal(0L)),
Literal(null, a.dataType),
Cast(Divide(
Cast(currentSum, DecimalType.Unlimited),
Cast(currentCount, DecimalType.Unlimited)), a.dataType))
case _ =>
If(EqualTo(currentCount, Literal(0L)),
Literal(null, a.dataType),
Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType)))
}
AggregateEvaluation(
currentCount :: currentSum :: Nil,

View file

@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest {
test("average") {
checkAnswer(
testData2.groupBy()(avg('a)),
testData2.aggregate(avg('a)),
2.0)
checkAnswer(
testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
(2.0, 6.0) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a)),
BigDecimal(2.0))
checkAnswer(
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
(BigDecimal(2.0), BigDecimal(6)) :: Nil)
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
BigDecimal(2.0))
checkAnswer(
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
(BigDecimal(2.0), BigDecimal(6)) :: Nil)
}
test("null average") {
checkAnswer(
testData3.groupBy()(avg('b)),
testData3.aggregate(avg('b)),
2.0)
checkAnswer(
testData3.groupBy()(avg('b), countDistinct('b)),
testData3.aggregate(avg('b), countDistinct('b)),
(2.0, 1) :: Nil)
checkAnswer(
testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
(2.0, 2.0) :: Nil)
}
test("zero average") {
checkAnswer(
emptyTableData.aggregate(avg('a)),
null)
checkAnswer(
emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
(null, null) :: Nil)
}
test("count") {
assert(testData2.count() === testData2.map(_ => 1).count())
checkAnswer(
testData2.aggregate(count('a), sumDistinct('a)), // non-partial
(6, 6.0) :: Nil)
}
test("null count") {
@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest {
)
checkAnswer(
testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
(2, 1, 2, 2, 1) :: Nil
)
checkAnswer(
testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
(1, 1, 2) :: Nil
)
}
test("zero count") {
assert(emptyTableData.count() === 0)
checkAnswer(
emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
(0, null) :: Nil)
}
test("zero sum") {
checkAnswer(
emptyTableData.aggregate(sum('a)),
null)
}
test("zero sum distinct") {
checkAnswer(
emptyTableData.aggregate(sumDistinct('a)),
null)
}
test("except") {

View file

@ -54,6 +54,17 @@ object TestData {
TestData2(3, 2) :: Nil).toSchemaRDD
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
val decimalData =
TestSQLContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
DecimalData(3, 2) :: Nil).toSchemaRDD
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
val binaryData =
TestSQLContext.sparkContext.parallelize(