[SPARK-19691][SQL] Fix ClassCastException when calculating percentile of decimal column

## What changes were proposed in this pull request?
This PR fixes the `ClassCastException` shown below:
```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
 java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
	at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
	at
```
This fix converts Catalyst values (i.e., `Decimal`) into Scala ones by using `CatalystTypeConverters`.
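
For context, the core of the change (see the `Percentile.scala` diff below) is that aggregation buffer keys are now held as `AnyRef` and only normalized to `Double` at interpolation time, via a small helper that accepts both Catalyst `Decimal` and `java.lang.Number`. A minimal standalone sketch of that pattern:

```scala
import org.apache.spark.sql.types.Decimal

// Catalyst's Decimal does not extend java.lang.Number, so casting it to Number
// (the old behavior) throws ClassCastException; matching on the runtime type works.
def toDoubleValue(d: Any): Double = d match {
  case d: Decimal => d.toDouble
  case n: Number => n.doubleValue
}

toDoubleValue(Decimal(1.5))                // 1.5
toDoubleValue(java.lang.Long.valueOf(3L))  // 3.0
```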

## How was this patch tested?
Added a test in `DataFrameSuite`.
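
As a quick manual check (a sketch, assuming a build that includes this patch), the original repro from the description no longer throws:

```scala
// Previously failed in Percentile.update with a ClassCastException.
spark.range(10)
  .selectExpr("cast (id as decimal) as x")
  .selectExpr("percentile(x, 0.5)")
  .collect()
```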

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17028 from maropu/SPARK-19691.
Authored by Takeshi Yamamuro on 2017-02-23 16:28:36 +01:00, committed by Herman van Hovell.
Commit 93aa427159 (parent 769aa0f1d2); 3 changed files with 45 additions and 37 deletions.


`Percentile.scala` (`org.apache.spark.sql.catalyst.expressions.aggregate`):

@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
     frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
   override def update(
-      buffer: OpenHashMap[Number, Long],
-      input: InternalRow): OpenHashMap[Number, Long] = {
-    val key = child.eval(input).asInstanceOf[Number]
+      buffer: OpenHashMap[AnyRef, Long],
+      input: InternalRow): OpenHashMap[AnyRef, Long] = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
     val frqValue = frequencyExpression.eval(input)
 
     // Null values are ignored in counts map.
@@ -155,32 +160,32 @@ case class Percentile(
   }
 
   override def merge(
-      buffer: OpenHashMap[Number, Long],
-      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+      buffer: OpenHashMap[AnyRef, Long],
+      other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
     buffer
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -200,7 +205,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
@@ -238,7 +242,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10) // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
      var sizeOfNextRow = ins.readInt()
      while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()


`PercentileSuite.scala`:

@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(childExpression, percentageExpression)
 
     // Test with rows without frequency
-    val rows = (1 to count).map( x => Seq(x))
-    runTest( agg, rows, expectedPercentiles)
+    val rows = (1 to count).map(x => Seq(x))
+    runTest(agg, rows, expectedPercentiles)
 
     // Test with row with frequency. Second and third columns are frequency in Int and Long
     val countForFrequencyTest = 1000
-    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+    val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
     val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
 
     val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
     val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
-    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
     val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
-    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     // Run test with Flatten data
-    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
-      (1 to current).map( y => current )).map( Seq(_))
+    val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+      (1 to current).map(y => current )).map(Seq(_))
     runTest(agg, flattenRows, expectedPercentilesWithFrquency)
   }
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
    }
 
    val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-    for ( dataType <- validDataTypes;
+    for (dataType <- validDataTypes;
      frequencyType <- validFrequencyTypes) {
      val child = AttributeReference("a", dataType)()
      val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
      StringType, DateType, TimestampType,
      CalendarIntervalType, NullType)
 
-    for( dataType <- invalidDataTypes;
+    for(dataType <- invalidDataTypes;
      frequencyType <- validFrequencyTypes) {
      val child = AttributeReference("a", dataType)()
      val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
        s"'`a`' is of ${dataType.simpleString} type."))
    }
 
-    for( dataType <- validDataTypes;
+    for(dataType <- validDataTypes;
      frequencyType <- invalidFrequencyDataTypes) {
      val child = AttributeReference("a", dataType)()
      val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
      agg.update(buffer, InternalRow(1, -5))
      agg.eval(buffer)
    }
-    assert( caught.getMessage.startsWith("Negative values found in "))
+    assert(caught.getMessage.startsWith("Negative values found in "))
  }
 
  private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
    left.size == right.size && left.forall { case (key, count) =>
      right.apply(key) == count
    }


`DataFrameSuite.scala`:

@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
     checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }