[SPARK-24119][SQL] Add interpreted execution to SortPrefix expression
## What changes were proposed in this pull request? Implemented eval in SortPrefix expression. ## How was this patch tested? - ran existing sbt SQL tests - added unit test - ran existing Python SQL tests - manual tests: disabling codegen -- patching code to disable beyond what spark.sql.codegen.wholeStage=false can do -- and running sbt SQL tests Author: Bruce Robbins <bersprockets@gmail.com> Closes #21231 from bersprockets/sortprefixeval.
This commit is contained in:
parent
e76b0124fb
commit
1462bba4fd
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
|
||||
|
||||
abstract sealed class SortDirection {
|
||||
|
@ -148,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
|
|||
(!child.isAscending && child.nullOrdering == NullsLast)
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
|
||||
private lazy val calcPrefix: Any => Long = child.child.dataType match {
|
||||
case BooleanType => (raw) =>
|
||||
if (raw.asInstanceOf[Boolean]) 1 else 0
|
||||
case DateType | TimestampType | _: IntegralType => (raw) =>
|
||||
raw.asInstanceOf[java.lang.Number].longValue()
|
||||
case FloatType | DoubleType => (raw) => {
|
||||
val dVal = raw.asInstanceOf[java.lang.Number].doubleValue()
|
||||
DoublePrefixComparator.computePrefix(dVal)
|
||||
}
|
||||
case StringType => (raw) =>
|
||||
StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String])
|
||||
case BinaryType => (raw) =>
|
||||
BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]])
|
||||
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
|
||||
_.asInstanceOf[Decimal].toUnscaledLong
|
||||
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
|
||||
val p = Decimal.MAX_LONG_DIGITS
|
||||
val s = p - (dt.precision - dt.scale)
|
||||
(raw) => {
|
||||
val value = raw.asInstanceOf[Decimal]
|
||||
if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue
|
||||
}
|
||||
case dt: DecimalType => (raw) =>
|
||||
DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble)
|
||||
case _ => (Any) => 0L
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val value = child.child.eval(input)
|
||||
if (value == null) {
|
||||
null
|
||||
} else {
|
||||
calcPrefix(value)
|
||||
}
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val childCode = child.child.genCode(ctx)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.util.TimeZone
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
|
||||
|
||||
class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
|
||||
test("SortPrefix") {
|
||||
val b1 = Literal.create(false, BooleanType)
|
||||
val b2 = Literal.create(true, BooleanType)
|
||||
val i1 = Literal.create(20132983, IntegerType)
|
||||
val i2 = Literal.create(-20132983, IntegerType)
|
||||
val l1 = Literal.create(20132983, LongType)
|
||||
val l2 = Literal.create(-20132983, LongType)
|
||||
val millis = 1524954911000L;
|
||||
// Explicitly choose a time zone, since Date objects can create different values depending on
|
||||
// local time zone of the machine on which the test is running
|
||||
val oldDefaultTZ = TimeZone.getDefault
|
||||
val d1 = try {
|
||||
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
|
||||
Literal.create(new java.sql.Date(millis), DateType)
|
||||
} finally {
|
||||
TimeZone.setDefault(oldDefaultTZ)
|
||||
}
|
||||
val t1 = Literal.create(new Timestamp(millis), TimestampType)
|
||||
val f1 = Literal.create(0.7788229f, FloatType)
|
||||
val f2 = Literal.create(-0.7788229f, FloatType)
|
||||
val db1 = Literal.create(0.7788229d, DoubleType)
|
||||
val db2 = Literal.create(-0.7788229d, DoubleType)
|
||||
val s1 = Literal.create("T", StringType)
|
||||
val s2 = Literal.create("This is longer than 8 characters", StringType)
|
||||
val bin1 = Literal.create(Array[Byte](12), BinaryType)
|
||||
val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]),
|
||||
BinaryType)
|
||||
val dec1 = Literal(Decimal(20132983L, 10, 2))
|
||||
val dec2 = Literal(Decimal(20132983L, 19, 2))
|
||||
val dec3 = Literal(Decimal(20132983L, 21, 2))
|
||||
val list1 = Literal(List(1, 2), ArrayType(IntegerType))
|
||||
val nullVal = Literal.create(null, IntegerType)
|
||||
|
||||
checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
|
||||
checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L)
|
||||
checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L)
|
||||
checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L)
|
||||
checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L)
|
||||
checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L)
|
||||
// For some reason, the Literal.create code gives us the number of days since the epoch
|
||||
checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L)
|
||||
checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000)
|
||||
checkEvaluation(SortPrefix(SortOrder(f1, Ascending)),
|
||||
DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble))
|
||||
checkEvaluation(SortPrefix(SortOrder(f2, Ascending)),
|
||||
DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble))
|
||||
checkEvaluation(SortPrefix(SortOrder(db1, Ascending)),
|
||||
DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double]))
|
||||
checkEvaluation(SortPrefix(SortOrder(db2, Ascending)),
|
||||
DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double]))
|
||||
checkEvaluation(SortPrefix(SortOrder(s1, Ascending)),
|
||||
StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String]))
|
||||
checkEvaluation(SortPrefix(SortOrder(s2, Ascending)),
|
||||
StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String]))
|
||||
checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)),
|
||||
BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]]))
|
||||
checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)),
|
||||
BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]]))
|
||||
checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L)
|
||||
checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L)
|
||||
checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)),
|
||||
DoublePrefixComparator.computePrefix(201329.83d))
|
||||
checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L)
|
||||
checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue