diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 6b7cf7991d..8433a93ea3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
     frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
   override def update(
-      buffer: OpenHashMap[Number, Long],
-      input: InternalRow): OpenHashMap[Number, Long] = {
-    val key = child.eval(input).asInstanceOf[Number]
+      buffer: OpenHashMap[AnyRef, Long],
+      input: InternalRow): OpenHashMap[AnyRef, Long] = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
     val frqValue = frequencyExpression.eval(input)
 
     // Null values are ignored in counts map.
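Note on the core change above: a `DecimalType` column evaluates to `org.apache.spark.sql.types.Decimal`, which does not extend `java.lang.Number`, so the previous `asInstanceOf[Number]` cast in `update` threw a `ClassCastException` (SPARK-19691). Keying the buffer by `AnyRef` and converting to `Double` only on output avoids the cast. Below is a minimal standalone sketch of the `toDoubleValue` dispatch; the `Decimal` here is a stand-in case class, not Spark's:

```scala
// Standalone sketch (stand-in types, not Spark's): why a java.lang.Number
// cast fails for Decimal keys, and how a match-based conversion avoids it.
object ToDoubleValueSketch {
  // Stand-in for org.apache.spark.sql.types.Decimal, which wraps a numeric
  // value but does NOT extend java.lang.Number.
  final case class Decimal(underlying: BigDecimal) {
    def toDouble: Double = underlying.toDouble
  }

  // Mirrors the dispatch added in this patch: handle Decimal first, then
  // fall back to java.lang.Number for boxed Int/Long/Double/... keys.
  def toDoubleValue(d: Any): Double = d match {
    case d: Decimal => d.toDouble
    case n: Number => n.doubleValue
  }

  def main(args: Array[String]): Unit = {
    println(toDoubleValue(Decimal(BigDecimal("1.5"))))    // 1.5
    println(toDoubleValue(java.lang.Integer.valueOf(3)))  // 3.0
    // The pre-patch cast would fail here:
    // Decimal(BigDecimal("1.5")).asInstanceOf[Number]    // ClassCastException
  }
}
```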
@@ -155,32 +160,32 @@ case class Percentile(
   }
 
   override def merge(
-      buffer: OpenHashMap[Number, Long],
-      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+      buffer: OpenHashMap[AnyRef, Long],
+      other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
     buffer
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -200,7 +205,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
@@ -238,7 +242,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
         val bs = new Array[Byte](sizeOfNextRow)
         ins.readFully(bs)
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()
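With the keys widened to `AnyRef`, `getPercentile` can no longer return the key itself as a `Number`; it now returns `Double` from every exit via `toDoubleValue`, and `getPercentiles` drops its trailing `.doubleValue()` call. The interpolation itself is unchanged; here is a simplified sketch of it, with keys pre-converted to `Double` and a linear scan where Spark's version uses a binary search over the cumulative counts:

```scala
// Simplified sketch of the linear interpolation in getPercentile, over
// (key, cumulative count) pairs sorted by key; position = (N - 1) * p.
object InterpolationSketch {
  def getPercentile(aggreCounts: Seq[(Double, Long)], position: Double): Double = {
    val lower = position.floor.toLong
    val higher = position.ceil.toLong
    // The key at logical row `pos` is the first pair whose cumulative
    // count exceeds pos (Spark finds this index with a binary search).
    def keyAt(pos: Long): Double =
      aggreCounts.find(_._2 > pos).map(_._1).getOrElse(aggreCounts.last._1)
    val lowerKey = keyAt(lower)
    val higherKey = keyAt(higher)
    if (higher == lower || higherKey == lowerKey) lowerKey
    else (higher - position) * lowerKey + (position - lower) * higherKey
  }

  def main(args: Array[String]): Unit = {
    // Rows 1.0, 2.0, 3.0 => cumulative counts 1, 2, 3 (N = 3).
    val counts = Seq(1.0 -> 1L, 2.0 -> 2L, 3.0 -> 3L)
    println(getPercentile(counts, 1.0)) // p = 0.50 => 2.0 (exact row)
    println(getPercentile(counts, 0.5)) // p = 0.25 => 1.5 (interpolated)
  }
}
```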
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index 1533fe5f90..2420ba513f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(childExpression, percentageExpression)
 
     // Test with rows without frequency
-    val rows = (1 to count).map( x => Seq(x))
-    runTest( agg, rows, expectedPercentiles)
+    val rows = (1 to count).map(x => Seq(x))
+    runTest(agg, rows, expectedPercentiles)
 
     // Test with row with frequency. Second and third columns are frequency in Int and Long
     val countForFrequencyTest = 1000
-    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+    val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
     val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
 
     val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
     val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
-    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
     val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
-    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     // Run test with Flatten data
-    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
-      (1 to current).map( y => current )).map( Seq(_))
+    val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+      (1 to current).map(y => current )).map(Seq(_))
     runTest(agg, flattenRows, expectedPercentilesWithFrquency)
   }
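The `new Integer(key)` in the serialization test above makes the boxing explicit: with the buffer keyed by `AnyRef`, a Scala `Int` has to go in as a boxed `java.lang.Integer` (`Integer.valueOf(key)` would be equivalent). A sketch of the same bookkeeping against a plain `mutable.HashMap` stand-in rather than Spark's `OpenHashMap`:

```scala
// Sketch: counting occurrences in an AnyRef-keyed map, boxing Int keys
// explicitly. mutable.HashMap stands in for Spark's OpenHashMap here.
import scala.collection.mutable

object BoxedCountsSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new mutable.HashMap[AnyRef, Long]()
    (1 to 3).foreach { key =>
      val boxed = java.lang.Integer.valueOf(key) // explicit Int => Integer
      buffer.update(boxed, buffer.getOrElse(boxed, 0L) + 1L)
    }
    println(buffer) // counts: 1 -> 1, 2 -> 1, 3 -> 1
  }
}
```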
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
     }
 
     val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-    for ( dataType <- validDataTypes;
+    for (dataType <- validDataTypes;
       frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
       StringType, DateType, TimestampType,
       CalendarIntervalType, NullType)
 
-    for( dataType <- invalidDataTypes;
+    for(dataType <- invalidDataTypes;
        frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
           s"'`a`' is of ${dataType.simpleString} type."))
     }
 
-    for( dataType <- validDataTypes;
+    for(dataType <- validDataTypes;
        frequencyType <- invalidFrequencyDataTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
       agg.update(buffer, InternalRow(1, -5))
       agg.eval(buffer)
     }
-    assert( caught.getMessage.startsWith("Negative values found in "))
+    assert(caught.getMessage.startsWith("Negative values found in "))
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }
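`compareEquals` keeps working after the serialize/deserialize round-trip because boxed keys compare by value: `deserialize` rebuilds fresh key objects, but `right.apply(key)` only relies on `equals`/`hashCode`, not reference identity. A small illustration of that assumption:

```scala
// Sketch: boxed Integer keys are distinct objects after a round-trip but
// still equal by value, which is what hash-map lookups depend on.
object BoxedEqualitySketch {
  def main(args: Array[String]): Unit = {
    val a: AnyRef = java.lang.Integer.valueOf(1000)
    val b: AnyRef = new Integer(1000)   // different object, same value
    println(a eq b)                     // false: distinct references
    println(a == b)                     // true: Scala == delegates to equals
    println(a.hashCode == b.hashCode)   // true: consistent lookup hashing
  }
}
```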
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e6338ab7cd..5e65436079 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
     checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }
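For reference, the new end-to-end test can also be reproduced outside the suite; a hedged sketch against a local `SparkSession` (assumes a Spark build that includes this patch):

```scala
// Sketch: end-to-end check of percentile() over a decimal column.
// Before this patch, Percentile.update threw a ClassCastException here.
import org.apache.spark.sql.SparkSession

object PercentileOnDecimal {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("SPARK-19691")
      .getOrCreate()
    spark.range(1)
      .selectExpr("CAST(id AS DECIMAL) AS x")
      .selectExpr("percentile(x, 0.5)")
      .show() // single row: 0.0
    spark.stop()
  }
}
```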