[SPARK-4318][SQL] Fix empty sum distinct.
Executing sum distinct for empty table throws `java.lang.UnsupportedOperationException: empty.reduceLeft`. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #3184 from ueshin/issues/SPARK-4318 and squashes the following commits: 8168c42 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4318 66fdb0a [Takuya UESHIN] Re-refine aggregate functions. 6186eb4 [Takuya UESHIN] Fix Sum of GeneratedAggregate. d2975f6 [Takuya UESHIN] Refine Sum and Average of GeneratedAggregate. 1bba675 [Takuya UESHIN] Refine Sum, SumDistinct and Average functions. 917e533 [Takuya UESHIN] Use aggregate instead of groupBy(). 1a5f874 [Takuya UESHIN] Add tests to be executed as non-partial aggregation. a5a57d2 [Takuya UESHIN] Fix empty Average. 22799dc [Takuya UESHIN] Fix empty Sum and SumDistinct. 65b7dd2 [Takuya UESHIN] Fix empty sum distinct.
This commit is contained in:
parent
98e9419784
commit
2c2e7a44db
|
@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
|
|||
|
||||
override def asPartial: SplitEvaluation = {
|
||||
val partialCount = Alias(Count(child), "PartialCount")()
|
||||
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
|
||||
SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
|
||||
}
|
||||
|
||||
override def newInstance() = new CountFunction(child, this)
|
||||
|
@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
|
|||
|
||||
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
|
||||
|
||||
override def nullable = false
|
||||
override def nullable = true
|
||||
|
||||
override def dataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
|
@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
|
|||
override def toString = s"AVG($child)"
|
||||
|
||||
override def asPartial: SplitEvaluation = {
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
val partialCount = Alias(Count(child), "PartialCount")()
|
||||
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
// Turn the results to unlimited decimals for the division, before going back to fixed
|
||||
// Turn the child to unlimited decimals for calculation, before going back to fixed
|
||||
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
|
||||
val partialCount = Alias(Count(child), "PartialCount")()
|
||||
|
||||
val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
|
||||
val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
|
||||
SplitEvaluation(
|
||||
|
@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
|
|||
partialCount :: partialSum :: Nil)
|
||||
|
||||
case _ =>
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
val partialCount = Alias(Count(child), "PartialCount")()
|
||||
|
||||
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
|
||||
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
|
||||
SplitEvaluation(
|
||||
|
@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
|
|||
|
||||
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
|
||||
|
||||
override def nullable = false
|
||||
override def nullable = true
|
||||
|
||||
override def dataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
|
@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
|
|||
override def toString = s"SUM($child)"
|
||||
|
||||
override def asPartial: SplitEvaluation = {
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
Sum(partialSum.toAttribute),
|
||||
partialSum :: Nil)
|
||||
child.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
Cast(Sum(partialSum.toAttribute), dataType),
|
||||
partialSum :: Nil)
|
||||
|
||||
case _ =>
|
||||
val partialSum = Alias(Sum(child), "PartialSum")()
|
||||
SplitEvaluation(
|
||||
Sum(partialSum.toAttribute),
|
||||
partialSum :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
override def newInstance() = new SumFunction(child, this)
|
||||
|
@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
|
|||
case class SumDistinct(child: Expression)
|
||||
extends AggregateExpression with trees.UnaryNode[Expression] {
|
||||
|
||||
override def nullable = false
|
||||
override def nullable = true
|
||||
|
||||
override def dataType = child.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
|
@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
|
|||
|
||||
def this() = this(null, null) // Required for serialization.
|
||||
|
||||
private val zero = Cast(Literal(0), expr.dataType)
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
private val zero = Cast(Literal(0), calcType)
|
||||
|
||||
private var count: Long = _
|
||||
private val sum = MutableLiteral(zero.eval(null), expr.dataType)
|
||||
private val sumAsDouble = Cast(sum, DoubleType)
|
||||
private val sum = MutableLiteral(zero.eval(null), calcType)
|
||||
|
||||
private def addFunction(value: Any) = Add(sum, Literal(value))
|
||||
private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType))
|
||||
|
||||
override def eval(input: Row): Any =
|
||||
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
|
||||
override def eval(input: Row): Any = {
|
||||
if (count == 0L) {
|
||||
null
|
||||
} else {
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(Divide(
|
||||
Cast(sum, DecimalType.Unlimited),
|
||||
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
|
||||
case _ =>
|
||||
Divide(
|
||||
Cast(sum, dataType),
|
||||
Cast(Literal(count), dataType)).eval(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def update(input: Row): Unit = {
|
||||
val evaluatedExpr = expr.eval(input)
|
||||
|
@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction(
|
|||
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
|
||||
def this() = this(null, null) // Required for serialization.
|
||||
|
||||
private val zero = Cast(Literal(0), expr.dataType)
|
||||
private val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
private val sum = MutableLiteral(zero.eval(null), expr.dataType)
|
||||
private val zero = Cast(Literal(0), calcType)
|
||||
|
||||
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
|
||||
private val sum = MutableLiteral(null, calcType)
|
||||
|
||||
private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
|
||||
|
||||
override def update(input: Row): Unit = {
|
||||
sum.update(addFunction, input)
|
||||
}
|
||||
|
||||
override def eval(input: Row): Any = sum.eval(null)
|
||||
override def eval(input: Row): Any = {
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(sum, dataType).eval(null)
|
||||
case _ => sum.eval(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
|
||||
|
@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
|
|||
}
|
||||
}
|
||||
|
||||
override def eval(input: Row): Any =
|
||||
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
|
||||
override def eval(input: Row): Any = {
|
||||
if (seen.size == 0) {
|
||||
null
|
||||
} else {
|
||||
Cast(Literal(
|
||||
seen.reduceLeft(
|
||||
dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
|
||||
dataType).eval(null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class CountDistinctFunction(
|
||||
|
|
|
@ -83,29 +83,45 @@ case class GeneratedAggregate(
|
|||
|
||||
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case Sum(expr) =>
|
||||
val resultType = expr.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 10, scale)
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
case s @ Sum(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
|
||||
val initialValue = Cast(Literal(0L), resultType)
|
||||
val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
|
||||
val initialValue = Literal(null, calcType)
|
||||
|
||||
// Coalasce avoids double calculation...
|
||||
// but really, common sub expression elimination would be better....
|
||||
val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
|
||||
val result = currentSum
|
||||
val zero = Cast(Literal(0), calcType)
|
||||
val updateFunction = Coalesce(
|
||||
Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
|
||||
val result =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
Cast(currentSum, s.dataType)
|
||||
case _ => currentSum
|
||||
}
|
||||
|
||||
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
|
||||
|
||||
case a @ Average(expr) =>
|
||||
val calcType =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
DecimalType.Unlimited
|
||||
case _ =>
|
||||
expr.dataType
|
||||
}
|
||||
|
||||
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
|
||||
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
|
||||
val currentSum = AttributeReference("currentSum", calcType, nullable = false)()
|
||||
val initialCount = Literal(0L)
|
||||
val initialSum = Cast(Literal(0L), expr.dataType)
|
||||
val initialSum = Cast(Literal(0L), calcType)
|
||||
|
||||
// If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
|
||||
// UnscaledValue will be null if and only if x is null; helps with Average on decimals
|
||||
|
@ -115,17 +131,21 @@ case class GeneratedAggregate(
|
|||
}
|
||||
|
||||
val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
|
||||
val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
|
||||
val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil)
|
||||
|
||||
val resultType = expr.dataType match {
|
||||
case DecimalType.Fixed(precision, scale) =>
|
||||
DecimalType(precision + 4, scale + 4)
|
||||
case DecimalType.Unlimited =>
|
||||
DecimalType.Unlimited
|
||||
case _ =>
|
||||
DoubleType
|
||||
}
|
||||
val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
|
||||
val result =
|
||||
expr.dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
If(EqualTo(currentCount, Literal(0L)),
|
||||
Literal(null, a.dataType),
|
||||
Cast(Divide(
|
||||
Cast(currentSum, DecimalType.Unlimited),
|
||||
Cast(currentCount, DecimalType.Unlimited)), a.dataType))
|
||||
case _ =>
|
||||
If(EqualTo(currentCount, Literal(0L)),
|
||||
Literal(null, a.dataType),
|
||||
Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType)))
|
||||
}
|
||||
|
||||
AggregateEvaluation(
|
||||
currentCount :: currentSum :: Nil,
|
||||
|
|
|
@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest {
|
|||
|
||||
test("average") {
|
||||
checkAnswer(
|
||||
testData2.groupBy()(avg('a)),
|
||||
testData2.aggregate(avg('a)),
|
||||
2.0)
|
||||
|
||||
checkAnswer(
|
||||
testData2.aggregate(avg('a), sumDistinct('a)), // non-partial
|
||||
(2.0, 6.0) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
decimalData.aggregate(avg('a)),
|
||||
BigDecimal(2.0))
|
||||
checkAnswer(
|
||||
decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial
|
||||
(BigDecimal(2.0), BigDecimal(6)) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
decimalData.aggregate(avg('a cast DecimalType(10, 2))),
|
||||
BigDecimal(2.0))
|
||||
checkAnswer(
|
||||
decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial
|
||||
(BigDecimal(2.0), BigDecimal(6)) :: Nil)
|
||||
}
|
||||
|
||||
test("null average") {
|
||||
checkAnswer(
|
||||
testData3.groupBy()(avg('b)),
|
||||
testData3.aggregate(avg('b)),
|
||||
2.0)
|
||||
|
||||
checkAnswer(
|
||||
testData3.groupBy()(avg('b), countDistinct('b)),
|
||||
testData3.aggregate(avg('b), countDistinct('b)),
|
||||
(2.0, 1) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
testData3.aggregate(avg('b), sumDistinct('b)), // non-partial
|
||||
(2.0, 2.0) :: Nil)
|
||||
}
|
||||
|
||||
test("zero average") {
|
||||
checkAnswer(
|
||||
emptyTableData.aggregate(avg('a)),
|
||||
null)
|
||||
|
||||
checkAnswer(
|
||||
emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial
|
||||
(null, null) :: Nil)
|
||||
}
|
||||
|
||||
test("count") {
|
||||
assert(testData2.count() === testData2.map(_ => 1).count())
|
||||
|
||||
checkAnswer(
|
||||
testData2.aggregate(count('a), sumDistinct('a)), // non-partial
|
||||
(6, 6.0) :: Nil)
|
||||
}
|
||||
|
||||
test("null count") {
|
||||
|
@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
|
||||
testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
|
||||
(2, 1, 2, 2, 1) :: Nil
|
||||
)
|
||||
|
||||
checkAnswer(
|
||||
testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial
|
||||
(1, 1, 2) :: Nil
|
||||
)
|
||||
}
|
||||
|
||||
test("zero count") {
|
||||
assert(emptyTableData.count() === 0)
|
||||
|
||||
checkAnswer(
|
||||
emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial
|
||||
(0, null) :: Nil)
|
||||
}
|
||||
|
||||
test("zero sum") {
|
||||
checkAnswer(
|
||||
emptyTableData.aggregate(sum('a)),
|
||||
null)
|
||||
}
|
||||
|
||||
test("zero sum distinct") {
|
||||
checkAnswer(
|
||||
emptyTableData.aggregate(sumDistinct('a)),
|
||||
null)
|
||||
}
|
||||
|
||||
test("except") {
|
||||
|
|
|
@ -54,6 +54,17 @@ object TestData {
|
|||
TestData2(3, 2) :: Nil).toSchemaRDD
|
||||
testData2.registerTempTable("testData2")
|
||||
|
||||
case class DecimalData(a: BigDecimal, b: BigDecimal)
|
||||
val decimalData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
DecimalData(1, 1) ::
|
||||
DecimalData(1, 2) ::
|
||||
DecimalData(2, 1) ::
|
||||
DecimalData(2, 2) ::
|
||||
DecimalData(3, 1) ::
|
||||
DecimalData(3, 2) :: Nil).toSchemaRDD
|
||||
decimalData.registerTempTable("decimalData")
|
||||
|
||||
case class BinaryData(a: Array[Byte], b: Int)
|
||||
val binaryData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
|
|
Loading…
Reference in a new issue