# [SPARK-19691][SQL] Fix ClassCastException when calculating percentile of decimal column
## What changes were proposed in this pull request?

This PR fixes the `ClassCastException` below:

```
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number
  at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:141)
  at org.apache.spark.sql.catalyst.expressions.aggregate.Percentile.update(Percentile.scala:58)
  at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.update(interfaces.scala:514)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$1$$anonfun$applyOrElse$1.apply(AggregationIterator.scala:171)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:187)
  at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateProcessRow$1.apply(AggregationIterator.scala:181)
  at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:151)
  at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
  at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:109)
  at ...
```

The fix simply converts Catalyst values (i.e., `Decimal`) into Scala ones by using `CatalystTypeConverters`.

## How was this patch tested?

Added a test in `DataFrameSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17028 from maropu/SPARK-19691.
parent 769aa0f1d2
commit 93aa427159
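For context before the diff: the root cause is that Spark's Catalyst `Decimal` does not extend `java.lang.Number`, so the unconditional `asInstanceOf[Number]` cast in `Percentile.update` can never succeed on a decimal column. A minimal sketch (not part of the patch) that demonstrates this:

```scala
import org.apache.spark.sql.types.Decimal

object DecimalCastDemo extends App {
  // Decimal is Spark's own fixed-point type (it extends Ordered[Decimal],
  // not java.lang.Number), so the old cast in update() had to fail:
  println(classOf[java.lang.Number].isAssignableFrom(classOf[Decimal]))  // false

  val d: Any = Decimal(5)
  // d.asInstanceOf[java.lang.Number]  // would throw the ClassCastException reported above
}
```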
`sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala`:

```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
```
```diff
@@ -61,7 +61,7 @@ case class Percentile(
     frequencyExpression : Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, Literal(1L), 0, 0)
```
```diff
@@ -130,15 +130,20 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
   override def update(
-      buffer: OpenHashMap[Number, Long],
-      input: InternalRow): OpenHashMap[Number, Long] = {
-    val key = child.eval(input).asInstanceOf[Number]
+      buffer: OpenHashMap[AnyRef, Long],
+      input: InternalRow): OpenHashMap[AnyRef, Long] = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
    val frqValue = frequencyExpression.eval(input)
 
    // Null values are ignored in counts map.
```
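The new `toDoubleValue` helper is the heart of the fix: it normalizes both Catalyst `Decimal` keys and boxed JVM numbers to `Double` before any arithmetic, instead of casting everything to `Number` up front. A standalone sketch of its contract (copied out of the class for illustration):

```scala
import org.apache.spark.sql.types.Decimal

// Mirrors the helper added above; accepts either key representation.
def toDoubleValue(d: Any): Double = d match {
  case d: Decimal => d.toDouble    // Catalyst decimal values
  case n: Number => n.doubleValue  // boxed Byte/Short/Integer/Long/Float/Double
}

assert(toDoubleValue(Decimal(2.5)) == 2.5)
assert(toDoubleValue(java.lang.Integer.valueOf(3)) == 3.0)
```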
```diff
@@ -155,32 +160,32 @@ case class Percentile(
   }
 
   override def merge(
-      buffer: OpenHashMap[Number, Long],
-      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+      buffer: OpenHashMap[AnyRef, Long],
+      other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
     buffer
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
```
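For readers new to this code: `getPercentiles` sorts the distinct keys, then turns per-key counts into running totals with `scanLeft`, so each key is paired with the position of its last occurrence. A small sketch on plain Scala data (hypothetical values, same shape as the code above):

```scala
// Keys 1, 2, 3 seen 2, 3, and 1 times respectively.
val sortedCounts = Seq((1, 2L), (2, 3L), (3, 1L))
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
  case ((_, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
// accumulatedCounts == Seq((1, 2), (2, 5), (3, 6))
// maxPosition == 6 - 1 == 5, so percentile positions range over [0, 5].
```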
```diff
@@ -200,7 +205,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
```
```diff
@@ -213,18 +218,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
```
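The interpolation formula itself is unchanged; only the operand conversion moved into `toDoubleValue`. A worked example of the arithmetic (hypothetical helper, same formula as above):

```scala
// Weighted average of the two keys straddling the requested position.
def interpolate(lowerKey: Double, higherKey: Double,
    lower: Long, higher: Long, position: Double): Double =
  (higher - position) * lowerKey + (position - lower) * higherKey

// percentile(0.5) over the two values {1, 2}: position = 0.5, between keys 1.0 and 2.0
assert(interpolate(1.0, 2.0, 0L, 1L, 0.5) == 1.5)
```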
```diff
@@ -238,7 +242,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
```
```diff
@@ -261,11 +265,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
```
```diff
@@ -274,7 +278,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()
```
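The serialization hunks only change the buffer's key type; with `row.get(0, child.dataType)`, the Catalyst value comes back as-is and no `Number` cast is needed. For orientation, a simplified sketch of the length-prefixed framing that `deserialize` reads back (an assumption drawn from the loop condition above; the real code writes `UnsafeRow` bytes, the payload here is a stand-in):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
Seq.fill(2)(Array.fill[Byte](16)(0)).foreach { payload =>  // stand-in row bytes
  out.writeInt(payload.length)  // size prefix: the `sizeOfNextRow` read back above
  out.write(payload)            // row content
}
out.writeInt(-1)  // a negative size ends the `while (sizeOfNextRow >= 0)` loop
```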
`sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala`:

```diff
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
```
```diff
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
```
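One subtlety in the test change above: with the buffer keyed by `AnyRef`, a Scala `Int` no longer implicitly converts to the key type, so the test boxes it explicitly. A minimal sketch:

```scala
import org.apache.spark.util.collection.OpenHashMap

val m = new OpenHashMap[AnyRef, Long]()
// A primitive Int must be boxed to a java.lang.Integer (an AnyRef) by hand:
m.changeValue(Integer.valueOf(1), 1L, _ + 1L)
```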
```diff
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(childExpression, percentageExpression)
 
     // Test with rows without frequency
-    val rows = (1 to count).map( x => Seq(x))
-    runTest( agg, rows, expectedPercentiles)
+    val rows = (1 to count).map(x => Seq(x))
+    runTest(agg, rows, expectedPercentiles)
 
     // Test with row with frequency. Second and third columns are frequency in Int and Long
     val countForFrequencyTest = 1000
-    val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+    val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
     val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
 
     val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
     val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
-    runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
     val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
-    runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+    runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
 
     // Run test with Flatten data
-    val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
-      (1 to current).map( y => current )).map( Seq(_))
+    val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+      (1 to current).map(y => current )).map(Seq(_))
     runTest(agg, flattenRows, expectedPercentilesWithFrquency)
   }
 
```
```diff
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
     }
 
     val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
-    for ( dataType <- validDataTypes;
+    for (dataType <- validDataTypes;
       frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
```
```diff
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
       StringType, DateType, TimestampType,
       CalendarIntervalType, NullType)
 
-    for( dataType <- invalidDataTypes;
+    for(dataType <- invalidDataTypes;
       frequencyType <- validFrequencyTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
```
```diff
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
       s"'`a`' is of ${dataType.simpleString} type."))
     }
 
-    for( dataType <- validDataTypes;
+    for(dataType <- validDataTypes;
       frequencyType <- invalidFrequencyDataTypes) {
       val child = AttributeReference("a", dataType)()
       val frq = AttributeReference("frq", frequencyType)()
```
```diff
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
       agg.update(buffer, InternalRow(1, -5))
       agg.eval(buffer)
     }
-    assert( caught.getMessage.startsWith("Negative values found in "))
+    assert(caught.getMessage.startsWith("Negative values found in "))
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }
```
`sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala`:

```diff
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
     checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }
```
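With the patch applied, the original repro from the description completes instead of throwing. A sketch of the expected spark-shell session (exact `Row` rendering may vary by version):

```scala
scala> spark.range(10).selectExpr("cast (id as decimal) as x").selectExpr("percentile(x, 0.5)").collect()
res0: Array[org.apache.spark.sql.Row] = Array([4.5])
```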