diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 2ce9d072c7..76a881146a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ abstract sealed class SortDirection { @@ -148,7 +149,41 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { (!child.isAscending && child.nullOrdering == NullsLast) } - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + private lazy val calcPrefix: Any => Long = child.child.dataType match { + case BooleanType => (raw) => + if (raw.asInstanceOf[Boolean]) 1 else 0 + case DateType | TimestampType | _: IntegralType => (raw) => + raw.asInstanceOf[java.lang.Number].longValue() + case FloatType | DoubleType => (raw) => { + val dVal = raw.asInstanceOf[java.lang.Number].doubleValue() + DoublePrefixComparator.computePrefix(dVal) + } + case StringType => (raw) => + StringPrefixComparator.computePrefix(raw.asInstanceOf[UTF8String]) + case BinaryType => (raw) => + BinaryPrefixComparator.computePrefix(raw.asInstanceOf[Array[Byte]]) + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + _.asInstanceOf[Decimal].toUnscaledLong + case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => + val p = Decimal.MAX_LONG_DIGITS + val s = p - (dt.precision - dt.scale) + (raw) => { + val value = raw.asInstanceOf[Decimal] + if (value.changePrecision(p, s)) value.toUnscaledLong else Long.MinValue + } + case dt: DecimalType => (raw) => + DoublePrefixComparator.computePrefix(raw.asInstanceOf[Decimal].toDouble) + case _ => (Any) => 0L + } + + override def eval(input: InternalRow): Any = { + val value = child.child.eval(input) + if (value == null) { + null + } else { + calcPrefix(value) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala new file mode 100644 index 0000000000..cc2e2a993d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.sql.{Date, Timestamp} +import java.util.TimeZone + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ + +class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SortPrefix") { + val b1 = Literal.create(false, BooleanType) + val b2 = Literal.create(true, BooleanType) + val i1 = Literal.create(20132983, IntegerType) + val i2 = Literal.create(-20132983, IntegerType) + val l1 = Literal.create(20132983, LongType) + val l2 = Literal.create(-20132983, LongType) + val millis = 1524954911000L; + // Explicitly choose a time zone, since Date objects can create different values depending on + // local time zone of the machine on which the test is running + val oldDefaultTZ = TimeZone.getDefault + val d1 = try { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Literal.create(new java.sql.Date(millis), DateType) + } finally { + TimeZone.setDefault(oldDefaultTZ) + } + val t1 = Literal.create(new Timestamp(millis), TimestampType) + val f1 = Literal.create(0.7788229f, FloatType) + val f2 = Literal.create(-0.7788229f, FloatType) + val db1 = Literal.create(0.7788229d, DoubleType) + val db2 = Literal.create(-0.7788229d, DoubleType) + val s1 = Literal.create("T", StringType) + val s2 = Literal.create("This is longer than 8 characters", StringType) + val bin1 = Literal.create(Array[Byte](12), BinaryType) + val bin2 = Literal.create(Array[Byte](12, 17, 99, 0, 0, 0, 2, 3, 0xf4.asInstanceOf[Byte]), + BinaryType) + val dec1 = Literal(Decimal(20132983L, 10, 2)) + val dec2 = Literal(Decimal(20132983L, 19, 2)) + val dec3 = Literal(Decimal(20132983L, 21, 2)) + val list1 = Literal(List(1, 2), ArrayType(IntegerType)) + val nullVal = Literal.create(null, IntegerType) + + checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(b2, Ascending)), 1L) + checkEvaluation(SortPrefix(SortOrder(i1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(i2, Ascending)), -20132983L) + checkEvaluation(SortPrefix(SortOrder(l1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(l2, Ascending)), -20132983L) + // For some reason, the Literal.create code gives us the number of days since the epoch + checkEvaluation(SortPrefix(SortOrder(d1, Ascending)), 17649L) + checkEvaluation(SortPrefix(SortOrder(t1, Ascending)), millis * 1000) + checkEvaluation(SortPrefix(SortOrder(f1, Ascending)), + DoublePrefixComparator.computePrefix(f1.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(f2, Ascending)), + DoublePrefixComparator.computePrefix(f2.value.asInstanceOf[Float].toDouble)) + checkEvaluation(SortPrefix(SortOrder(db1, Ascending)), + DoublePrefixComparator.computePrefix(db1.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(db2, Ascending)), + DoublePrefixComparator.computePrefix(db2.value.asInstanceOf[Double])) + checkEvaluation(SortPrefix(SortOrder(s1, Ascending)), + StringPrefixComparator.computePrefix(s1.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(s2, Ascending)), + StringPrefixComparator.computePrefix(s2.value.asInstanceOf[UTF8String])) + checkEvaluation(SortPrefix(SortOrder(bin1, Ascending)), + BinaryPrefixComparator.computePrefix(bin1.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(bin2, Ascending)), + BinaryPrefixComparator.computePrefix(bin2.value.asInstanceOf[Array[Byte]])) + checkEvaluation(SortPrefix(SortOrder(dec1, Ascending)), 20132983L) + checkEvaluation(SortPrefix(SortOrder(dec2, Ascending)), 2013298L) + checkEvaluation(SortPrefix(SortOrder(dec3, Ascending)), + DoublePrefixComparator.computePrefix(201329.83d)) + checkEvaluation(SortPrefix(SortOrder(list1, Ascending)), 0L) + checkEvaluation(SortPrefix(SortOrder(nullVal, Ascending)), null) + } +}