diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1ffc95c676..005de31660 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + s""" + boolean ${ev.isNull} = i.isNullAt($ordinal); + ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? + ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + """ + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 21adac1441..5f76a51267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val evaluated = child.eval(input) if (evaluated == null) null else cast(evaluated) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // TODO(cg): Add support for more data types. + (child.dataType, dataType) match { + + case (BinaryType, StringType) => + defineCodeGen (ctx, ev, c => + s"new ${ctx.stringType}().set($c)") + case (DateType, StringType) => + defineCodeGen(ctx, ev, c => + s"""new ${ctx.stringType}().set( + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") + // Special handling required for timestamps in hive test cases since the toString function + // does not match the expected output. + case (TimestampType, StringType) => + super.genCode(ctx, ev) + case (_, StringType) => + defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + + // fallback for DecimalType, this must be before other numeric types + case (_, dt: DecimalType) => + super.genCode(ctx, ev) + + case (BooleanType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + case (dt: DecimalType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c.isZero()") + case (dt: NumericType, BooleanType) => + defineCodeGen(ctx, ev, c => s"$c != 0") + + case (_: DecimalType, IntegerType) => + defineCodeGen(ctx, ev, c => s"($c).toInt()") + case (_: DecimalType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + case (_: NumericType, dt: NumericType) => + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + + case other => + super.genCode(ctx, ev) + } + } } object Cast { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b2b9d1a5e1..0ed576b3d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: Row = null): Any + /** + * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * can be used to generate the result of evaluating the expression on an input row. + * + * @param ctx a [[CodeGenContext]] + * @return [[GeneratedExpressionCode]] + */ + def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + ve + } + + /** + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. + * + * @param ctx a [[CodeGenContext]] + * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @return Java source code + */ + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = ${objectTerm} == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm}; + } + """ + } + /** * Returns `true` if this expression and all its children have been resolved to a specific schema * and input data types checking passed, and `false` if it still contains any unresolved @@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"($left $symbol $right)" + + /** + * Short hand for generating binary evaluation code, which depends on two sub-evaluations of + * the same type. If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (Term, Term) => Code): String = { + // TODO: Right now some timestamp tests fail if we enforce this... + if (left.dataType != right.dataType) { + // log.warn(s"${left.dataType} != ${right.dataType}") + } + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(eval1.primitive, eval2.primitive) + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if(!${eval2.isNull}) { + ${ev.primitive} = $resultCode; + } else { + ${ev.isNull} = true; + } + } + """ + } } private[sql] object BinaryExpression { @@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. + * + * As an example, the following does a boolean inversion (i.e. NOT). + * {{{ + * defineCodeGen(ctx, ev, c => s"!($c)") + * }}} + * + * @param f function that accepts a variable name and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: Term => Code): Code = { + val eval = child.gen(ctx) + // reuse the previous isNull + ev.isNull = eval.isNull + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = ${f(eval.primitive)}; + } + """ + } } // TODO Semantically we probably not need GroupExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a3770f998d..3ac7c92dcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -49,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + } + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) } @@ -67,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { if (value < 0) null else math.sqrt(value) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + if (${eval.primitive} < 0.0) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); + } + } + """ + } } /** @@ -86,6 +107,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = "" + override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -114,6 +138,17 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + case dt: DecimalType => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") + // byte and short are casted into int when add, minus, times or divide + case ByteType | ShortType => + defineCodeGen(ctx, ev, (eval1, eval2) => + s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + case _ => + defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryArithmetics must override either eval or evalInternal") } @@ -124,6 +159,7 @@ private[sql] object BinaryArithmetic { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" + override def decimalMethod: String = "$plus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -138,6 +174,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" + override def decimalMethod: String = "$minus" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -152,6 +189,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" + override def decimalMethod: String = "$times" override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) @@ -166,6 +204,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" + override def decimalMethod: String = "$divide" + override def nullable: Boolean = true override lazy val resolved = @@ -192,10 +232,40 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } } + + /** + * Special case handling due to division by 0 => null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ + } } case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" + override def decimalMethod: String = "reminder" + override def nullable: Boolean = true override lazy val resolved = @@ -222,6 +292,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } } } + + /** + * Special case handling for x % 0 ==> null. + */ + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val test = if (left.dataType.isInstanceOf[DecimalType]) { + s"${eval2.primitive}.isZero()" + } else { + s"${eval2.primitive} == 0" + } + val method = if (left.dataType.isInstanceOf[DecimalType]) { + s".$decimalMethod" + } else { + s"$symbol" + } + eval1.code + eval2.code + + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; + if (${eval1.isNull} || ${eval2.isNull} || $test) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + } + """ + } } /** @@ -271,7 +369,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } /** - * A function that calculates bitwise xor(^) of two numbers. + * A function that calculates bitwise xor of two numbers. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" @@ -313,6 +411,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") + } + protected override def evalInternal(evalE: Any) = not(evalE) } @@ -340,6 +442,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + if (ctx.isNativeType(left.dataType)) { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} > ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } else { + super.genCode(ctx, ev) + } + } override def toString: String = s"MaxOf($left, $right)" } @@ -367,5 +496,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + if (ctx.isNativeType(left.dataType)) { + + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + eval1.code + eval2.code + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(left.dataType)} ${ev.primitive} = + ${ctx.defaultValue(left.dataType)}; + + if (${eval1.isNull}) { + ${ev.isNull} = ${eval2.isNull}; + ${ev.primitive} = ${eval2.primitive}; + } else if (${eval2.isNull}) { + ${ev.isNull} = ${eval1.isNull}; + ${ev.primitive} = ${eval1.primitive}; + } else { + if (${eval1.primitive} < ${eval2.primitive}) { + ${ev.primitive} = ${eval1.primitive}; + } else { + ${ev.primitive} = ${eval2.primitive}; + } + } + """ + } else { + super.genCode(ctx, ev) + } + } + override def toString: String = s"MinOf($left, $right)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd604121b7..c8d0aaf79f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -32,6 +31,157 @@ import org.apache.spark.sql.types._ class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] +/** + * Java source for evaluating an [[Expression]] given a [[Row]] of input. + * + * @param code The sequence of statements required to evaluate the expression. + * @param isNull A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param primitive A term for a possible primitive value of the result of the evaluation. Not + * valid if `isNull` is set to `true`. + */ +case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) + +/** + * A context for codegen, which is used to bookkeeping the expressions those are not supported + * by codegen, then they are evaluated directly. The unsupported expression is appended at the + * end of `references`, the position of it is kept in the code, used to access and evaluate it. + */ +class CodeGenContext { + + /** + * Holding all the expressions those do not support codegen, will be evaluated directly. + */ + val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() + + val stringType: String = classOf[UTF8String].getName + val decimalType: String = classOf[Decimal].getName + + private val curId = new java.util.concurrent.atomic.AtomicInteger() + + /** + * Returns a term name that is unique within this instance of a `CodeGenerator`. + * + * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` + * function.) + */ + def freshName(prefix: String): Term = { + s"$prefix${curId.getAndIncrement}" + } + + /** + * Return the code to access a column for given DataType + */ + def getColumn(dataType: DataType, ordinal: Int): Code = { + if (isNativeType(dataType)) { + s"i.${accessorForType(dataType)}($ordinal)" + } else { + s"(${boxedType(dataType)})i.apply($ordinal)" + } + } + + /** + * Return the code to update a column in Row for given DataType + */ + def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = { + if (isNativeType(dataType)) { + s"${mutatorForType(dataType)}($ordinal, $value)" + } else { + s"update($ordinal, $value)" + } + } + + /** + * Return the name of accessor in Row for a DataType + */ + def accessorForType(dt: DataType): Term = dt match { + case IntegerType => "getInt" + case other => s"get${boxedType(dt)}" + } + + /** + * Return the name of mutator in Row for a DataType + */ + def mutatorForType(dt: DataType): Term = dt match { + case IntegerType => "setInt" + case other => s"set${boxedType(dt)}" + } + + /** + * Return the Java type for a DataType + */ + def javaType(dt: DataType): Term = dt match { + case IntegerType => "int" + case LongType => "long" + case ShortType => "short" + case ByteType => "byte" + case DoubleType => "double" + case FloatType => "float" + case BooleanType => "boolean" + case dt: DecimalType => decimalType + case BinaryType => "byte[]" + case StringType => stringType + case DateType => "int" + case TimestampType => "java.sql.Timestamp" + case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName + case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName + case _ => "Object" + } + + /** + * Return the boxed type in Java + */ + def boxedType(dt: DataType): Term = dt match { + case IntegerType => "Integer" + case LongType => "Long" + case ShortType => "Short" + case ByteType => "Byte" + case DoubleType => "Double" + case FloatType => "Float" + case BooleanType => "Boolean" + case DateType => "Integer" + case _ => javaType(dt) + } + + /** + * Return the representation of default value for given DataType + */ + def defaultValue(dt: DataType): Term = dt match { + case BooleanType => "false" + case FloatType => "-1.0f" + case ShortType => "(short)-1" + case LongType => "-1L" + case ByteType => "(byte)-1" + case DoubleType => "-1.0" + case IntegerType => "-1" + case DateType => "-1" + case _ => "null" + } + + /** + * Returns a function to generate equal expression in Java + */ + def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { + case BinaryType => { case (eval1, eval2) => + s"java.util.Arrays.equals($eval1, $eval2)" } + case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => + { case (eval1, eval2) => s"$eval1 == $eval2" } + case other => + { case (eval1, eval2) => s"$eval1.equals($eval2)" } + } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) +} + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -39,14 +189,9 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] */ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging { - protected val rowType = classOf[Row].getName - protected val stringType = classOf[UTF8String].getName - protected val decimalType = classOf[Decimal].getName - protected val exprType = classOf[Expression].getName - protected val mutableRowType = classOf[MutableRow].getName - protected val genericMutableRowType = classOf[GenericMutableRow].getName - - private val curId = new java.util.concurrent.atomic.AtomicInteger() + protected val exprType: String = classOf[Expression].getName + protected val mutableRowType: String = classOf[MutableRow].getName + protected val genericMutableRowType: String = classOf[GenericMutableRow].getName /** * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. @@ -75,10 +220,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin */ protected def compile(code: String): Class[_] = { val startTime = System.nanoTime() - val clazz = new ClassBodyEvaluator(code).getClazz() + val clazz = try { + new ClassBodyEvaluator(code).getClazz() + } catch { + case e: Exception => + logError(s"failed to compile:\n $code", e) + throw e + } val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms") + logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") clazz } @@ -112,586 +263,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** Generates the requested evaluator given already bound expression(s). */ def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) - /** - * Returns a term name that is unique within this instance of a `CodeGenerator`. - * - * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` - * function.) - */ - protected def freshName(prefix: String): String = { - s"$prefix${curId.getAndIncrement}" - } - - /** - * Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input. - * - * @param code The sequence of statements required to evaluate the expression. - * @param nullTerm A term that holds a boolean value representing whether the expression evaluated - * to null. - * @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not - * valid if `nullTerm` is set to `true`. - * @param objectTerm A possibly boxed version of the result of evaluating this expression. - */ - protected case class EvaluatedExpression( - code: String, - nullTerm: String, - primitiveTerm: String, - objectTerm: String) - - /** - * A context for codegen, which is used to bookkeeping the expressions those are not supported - * by codegen, then they are evaluated directly. The unsupported expression is appended at the - * end of `references`, the position of it is kept in the code, used to access and evaluate it. - */ - protected class CodeGenContext { - /** - * Holding all the expressions those do not support codegen, will be evaluated directly. - */ - val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]() - } - /** * Create a new codegen context for expression evaluator, used to store those * expressions that don't support codegen */ def newCodeGenContext(): CodeGenContext = { - new CodeGenContext() + new CodeGenContext } - - /** - * Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that - * can be used to determine the result of evaluating the expression on an input row. - */ - def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = { - val primitiveTerm = freshName("primitiveTerm") - val nullTerm = freshName("nullTerm") - val objectTerm = freshName("objectTerm") - - implicit class Evaluate1(e: Expression) { - def castOrNull(f: String => String, dataType: DataType): String = { - val eval = expressionEvaluator(e, ctx) - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - if (!$nullTerm) { - $primitiveTerm = ${f(eval.primitiveTerm)}; - } - """ - } - } - - implicit class Evaluate2(expressions: (Expression, Expression)) { - - /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f a function from two primitive term names to a tree that evaluates them. - */ - def evaluate(f: (String, String) => String): String = - evaluateAs(expressions._1.dataType)(f) - - def evaluateAs(resultType: DataType)(f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (expressions._1.dataType != expressions._2.dataType) { - log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}") - } - - val eval1 = expressionEvaluator(expressions._1, ctx) - val eval2 = expressionEvaluator(expressions._2, ctx) - val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}; - ${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)}; - if(!$nullTerm) { - $primitiveTerm = (${primitiveForType(resultType)})($resultCode); - } - """ - } - } - - val inputTuple = "i" - - // TODO: Skip generation of null handling code when expression are not nullable. - val primitiveEvaluation: PartialFunction[Expression, String] = { - case b @ BoundReference(ordinal, dataType, nullable) => - s""" - final boolean $nullTerm = $inputTuple.isNullAt($ordinal); - final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ? - ${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)}); - """ - - case expressions.Literal(null, dataType) => - s""" - final boolean $nullTerm = true; - ${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)}; - """ - - case expressions.Literal(value: UTF8String, StringType) => - val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}" - s""" - final boolean $nullTerm = false; - ${stringType} $primitiveTerm = - new ${stringType}().set(${arr}); - """ - - case expressions.Literal(value, FloatType) => - s""" - final boolean $nullTerm = false; - float $primitiveTerm = ${value}f; - """ - - case expressions.Literal(value, dt @ DecimalType()) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value); - """ - - case expressions.Literal(value, dataType) => - s""" - final boolean $nullTerm = false; - ${primitiveForType(dataType)} $primitiveTerm = $value; - """ - - case Cast(child @ BinaryType(), StringType) => - child.castOrNull(c => - s"new ${stringType}().set($c)", - StringType) - - case Cast(child @ DateType(), StringType) => - child.castOrNull(c => - s"""new ${stringType}().set( - org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", - StringType) - - case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt) - - case Cast(child @ DecimalType(), IntegerType) => - child.castOrNull(c => s"($c).toInt()", IntegerType) - - case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"($c).to${termForType(dt)}()", dt) - - case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] => - child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt) - - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. - case Cast(e, StringType) if e.dataType != TimestampType => - e.castOrNull(c => - s"new ${stringType}().set(String.valueOf($c))", - StringType) - - case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => - (e1, e2).evaluateAs (BooleanType) { - case (eval1, eval2) => - s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)" - } - - case EqualTo(e1, e2) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" } - - case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" } - case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" } - case LessThan(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" } - case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) => - (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" } - - case And(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) { - } else { - ${eval2.code} - if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) { - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = true; - } else { - $nullTerm = true; - } - } - """ - - case Or(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - s""" - ${eval1.code} - boolean $nullTerm = false; - boolean $primitiveTerm = false; - - if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) { - $primitiveTerm = true; - } else { - ${eval2.code} - if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) { - $primitiveTerm = true; - } else if (!${eval1.nullTerm} && !${eval2.nullTerm}) { - $primitiveTerm = false; - } else { - $nullTerm = true; - } - } - """ - - case Not(child) => - // Uh, bad function name... - child.castOrNull(c => s"!$c", BooleanType) - - case Add(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" } - case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" } - case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" } - case Divide(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = null; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm}); - } - """ - case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm}); - } - """ - - case Add(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" } - case Subtract(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" } - case Multiply(e1, e2) => - (e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" } - case Divide(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm}; - } - """ - case Remainder(e1, e2) => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = 0; - if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) { - $nullTerm = true; - } else { - $primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm}; - } - """ - - case IsNotNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = !${eval.nullTerm}; - """ - - case IsNull(e) => - val eval = expressionEvaluator(e, ctx) - s""" - ${eval.code} - boolean $nullTerm = false; - boolean $primitiveTerm = ${eval.nullTerm}; - """ - - case e @ Coalesce(children) => - s""" - boolean $nullTerm = true; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - """ + - children.map { c => - val eval = expressionEvaluator(c, ctx) - s""" - if($nullTerm) { - ${eval.code} - if(!${eval.nullTerm}) { - $nullTerm = false; - $primitiveTerm = ${eval.primitiveTerm}; - } - } - """ - }.mkString("\n") - - case e @ expressions.If(condition, trueValue, falseValue) => - val condEval = expressionEvaluator(condition, ctx) - val trueEval = expressionEvaluator(trueValue, ctx) - val falseEval = expressionEvaluator(falseValue, ctx) - - s""" - boolean $nullTerm = false; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - ${condEval.code} - if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) { - ${trueEval.code} - $nullTerm = ${trueEval.nullTerm}; - $primitiveTerm = ${trueEval.primitiveTerm}; - } else { - ${falseEval.code} - $nullTerm = ${falseEval.nullTerm}; - $primitiveTerm = ${falseEval.primitiveTerm}; - } - """ - - case NewSet(elementType) => - s""" - boolean $nullTerm = false; - ${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}(); - """ - - case AddItemToSet(item, set) => - val itemEval = expressionEvaluator(item, ctx) - val setEval = expressionEvaluator(set, ctx) - - val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - itemEval.code + setEval.code + - s""" - if (!${itemEval.nullTerm} && !${setEval.nullTerm}) { - (($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm}); - } - boolean $nullTerm = false; - ${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm}; - """ - - case CombineSets(left, right) => - val leftEval = expressionEvaluator(left, ctx) - val rightEval = expressionEvaluator(right, ctx) - - val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType - val htype = hashSetForType(elementType) - - leftEval.code + rightEval.code + - s""" - boolean $nullTerm = false; - ${htype} $primitiveTerm = - (${htype})${leftEval.primitiveTerm}; - $primitiveTerm.union((${htype})${rightEval.primitiveTerm}); - """ - - case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] => - val eval1 = expressionEvaluator(e1, ctx) - val eval2 = expressionEvaluator(e2, ctx) - - eval1.code + eval2.code + - s""" - boolean $nullTerm = false; - ${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)}; - - if (${eval1.nullTerm}) { - $nullTerm = ${eval2.nullTerm}; - $primitiveTerm = ${eval2.primitiveTerm}; - } else if (${eval2.nullTerm}) { - $nullTerm = ${eval1.nullTerm}; - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) { - $primitiveTerm = ${eval1.primitiveTerm}; - } else { - $primitiveTerm = ${eval2.primitiveTerm}; - } - } - """ - - case UnscaledValue(child) => - val childEval = expressionEvaluator(child, ctx) - - childEval.code + - s""" - boolean $nullTerm = ${childEval.nullTerm}; - long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong(); - """ - - case MakeDecimal(child, precision, scale) => - val eval = expressionEvaluator(child, ctx) - - eval.code + - s""" - boolean $nullTerm = ${eval.nullTerm}; - org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())}; - - if (!$nullTerm) { - $primitiveTerm = new org.apache.spark.sql.types.Decimal(); - $primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale); - $nullTerm = $primitiveTerm == null; - } - """ - } - - // If there was no match in the partial function above, we fall back on calling the interpreted - // expression evaluator. - val code: String = - primitiveEvaluation.lift.apply(e).getOrElse { - logError(s"No rules to generate $e") - ctx.references += e - s""" - /* expression: ${e} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean $nullTerm = $objectTerm == null; - ${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)}; - if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm; - """ - } - - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) - } - - protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = { - dataType match { - case StringType => s"(${stringType})$inputRow.apply($ordinal)" - case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)" - case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)" - } - } - - protected def setColumn( - destinationRow: String, - dataType: DataType, - ordinal: Int, - value: String): String = { - dataType match { - case StringType => s"$destinationRow.update($ordinal, $value)" - case dt: DataType if isNativeType(dt) => - s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" - case _ => s"$destinationRow.update($ordinal, $value)" - } - } - - protected def accessorForType(dt: DataType) = dt match { - case IntegerType => "getInt" - case other => s"get${termForType(dt)}" - } - - protected def mutatorForType(dt: DataType) = dt match { - case IntegerType => "setInt" - case other => s"set${termForType(dt)}" - } - - protected def hashSetForType(dt: DataType): String = dt match { - case IntegerType => classOf[IntegerHashSet].getName - case LongType => classOf[LongHashSet].getName - case unsupportedType => - sys.error(s"Code generation not support for hashset of type $unsupportedType") - } - - protected def primitiveForType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "int" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - - protected def defaultPrimitive(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "-1" - case LongType => "-1" - case ByteType => "-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" - case dt: DecimalType => "null" - case StringType => "null" - case _ => "null" - } - - protected def termForType(dt: DataType): String = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case dt: DecimalType => decimalType - case BinaryType => "byte[]" - case StringType => stringType - case DateType => "Integer" - case TimestampType => "java.sql.Timestamp" - case _ => "Object" - } - - /** - * List of data types that have special accessors and setters in [[Row]]. - */ - protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) - - /** - * Returns true if the data type has a special accessor and setter in [[Row]]. - */ - protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 638b53fe0f..e5ee2accd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = expressionEvaluator(e, ctx) + val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.nullTerm}) + if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - ${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)}; + mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0ff840dab3..36e155d164 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -52,15 +52,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => - val evalA = expressionEvaluator(order.child, ctx) - val evalB = expressionEvaluator(order.child, ctx) + val evalA = order.child.gen(ctx) + val evalB = order.child.gen(ctx) val asc = order.direction == Ascending val compare = order.child.dataType match { case BinaryType => s""" { - byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm}; - byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm}; + byte[] x = ${if (asc) evalA.primitive else evalB.primitive}; + byte[] y = ${if (!asc) evalB.primitive else evalA.primitive}; int j = 0; while (j < x.length && j < y.length) { if (x[j] != y[j]) return x[j] - y[j]; @@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _: NumericType => s""" - if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) { - if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) { + if (${evalA.primitive} != ${evalB.primitive}) { + if (${evalA.primitive} > ${evalB.primitive}) { return ${if (asc) "1" else "-1"}; } else { return ${if (asc) "-1" else "1"}; @@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit }""" case _ => s""" - int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm}); + int comp = ${evalA.primitive}.compare(${evalB.primitive}); if (comp != 0) { return ${if (asc) "comp" else "-comp"}; }""" @@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit ${evalA.code} i = $b; ${evalB.code} - if (${evalA.nullTerm} && ${evalB.nullTerm}) { + if (${evalA.isNull} && ${evalB.isNull}) { // Nothing - } else if (${evalA.nullTerm}) { + } else if (${evalA.isNull}) { return ${if (order.direction == Ascending) "-1" else "1"}; - } else if (${evalB.nullTerm}) { + } else if (${evalB.isNull}) { return ${if (order.direction == Ascending) "1" else "-1"}; } else { $compare diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index fb18769f00..4a547b5ce9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { protected def create(predicate: Expression): ((Row) => Boolean) = { val ctx = newCodeGenContext() - val eval = expressionEvaluator(predicate, ctx) + val eval = predicate.gen(ctx) val code = s""" import org.apache.spark.sql.Row; @@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { @Override public boolean eval(Row i) { ${eval.code} - return !${eval.nullTerm} && ${eval.primitiveTerm}; + return !${eval.isNull} && ${eval.primitive}; } }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d5be1fc12e..7caf4aaab8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -45,19 +45,19 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val ctx = newCodeGenContext() val columns = expressions.zipWithIndex.map { case (e, i) => - s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n" + s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" }.mkString("\n ") val initColumns = expressions.zipWithIndex.map { case (e, i) => - val eval = expressionEvaluator(e, ctx) + val eval = e.gen(ctx) s""" { // column$i ${eval.code} - nullBits[$i] = ${eval.nullTerm}; - if(!${eval.nullTerm}) { - c$i = ${eval.primitiveTerm}; + nullBits[$i] = ${eval.isNull}; + if (!${eval.isNull}) { + c$i = ${eval.primitive}; } } """ @@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n ") val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${termForType(e.dataType)})value; return;}" + s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: return c$i;" @@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) { + public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { if (isNullAt(i)) { - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; } switch (i) { $cases } - return ${defaultPrimitive(dataType)}; + return ${ctx.defaultValue(dataType)}; }""" } else { "" } }.mkString("\n") - val specificMutatorFunctions = nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { case (e, i) if e.dataType == dataType => s"case $i: { c$i = value; return; }" @@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { if (cases.count(_ != '\n') > 0) { s""" @Override - public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) { + public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { nullBits[i] = false; switch (i) { $cases @@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { case LongType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => - s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)" + s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" case _ => s"$col.hashCode()" } s"isNullAt($i) ? 0 : ($nonNull)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 7f1b12cdd5..6f9589d204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,6 +27,9 @@ import org.apache.spark.util.Utils */ package object codegen { + type Term = String + type Code = String + /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 65ba18924a..ddfadf314f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -35,6 +36,10 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { childResult.asInstanceOf[Decimal].toUnscaledLong } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") + } } /** Create a Decimal from an unscaled Long value */ @@ -53,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.decimalType} ${ev.primitive} = null; + + if (!${ev.isNull}) { + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( + ${eval.primitive}, $precision, $scale); + ${ev.isNull} = ${ev.primitive} == null; + } + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index d3ca3d9a4b..3a9271678b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -78,7 +79,60 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" + override def equals(other: Any): Boolean = other match { + case o: Literal => + dataType.equals(o.dataType) && + (value == null && null == o.value || value != null && value.equals(o.value)) + case _ => false + } + override def eval(input: Row): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + // change the isNull and primitive to consts, to inline them + if (value == null) { + ev.isNull = "true" + ev.primitive = ctx.defaultValue(dataType) + "" + } else { + dataType match { + case BooleanType => + ev.isNull = "false" + ev.primitive = value.toString + "" + case FloatType => // This must go before NumericType + val v = value.asInstanceOf[Float] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}f" + "" + } + case DoubleType => // This must go before NumericType + val v = value.asInstanceOf[Double] + if (v.isNaN || v.isInfinite) { + super.genCode(ctx, ev) + } else { + ev.isNull = "false" + ev.primitive = s"${value}" + "" + } + + case ByteType | ShortType => // This must go before NumericType + ev.isNull = "false" + ev.primitive = s"(${ctx.javaType(dataType)})$value" + "" + case dt: NumericType if !dt.isInstanceOf[DecimalType] => + ev.isNull = "false" + ev.primitive = value.toString + "" + // eval() version may be faster for non-primitive types + case other => + super.genCode(ctx, ev) + } + } + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index db853a2b97..88211acd77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") + } } case class Atan2(left: Expression, right: Expression) @@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } } case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 41b422346a..5563cd94bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs +import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} import org.apache.spark.sql.types._ @@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) if (result.isNaN) null else result } } + + // name of function in java.lang.Math + def funcName: String = name.toLowerCase + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } } case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") @@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { + override def funcName: String = "rint" +} case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") @@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { + override def funcName: String = "toDegrees" +} -case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { + override def funcName: String = "toRadians" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00565ec651..2e4b9ba678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees.LeafNode +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ object NamedExpression { @@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: Row): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable override def metadata: Metadata = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5070570b47..9ecfb3ccc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } result } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + + children.map { e => + val eval = e.gen(ctx) + s""" + if (${ev.isNull}) { + ${eval.code} + if (!${eval.isNull}) { + ${ev.isNull} = false; + ${ev.primitive} = ${eval.primitive}; + } + } + """ + }.mkString("\n") + } } case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { @@ -61,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = eval.isNull + eval.code + } + override def toString: String = s"IS NULL $child" } @@ -72,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E override def eval(input: Row): Any = { child.eval(input) != null } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval = child.gen(ctx) + ev.isNull = "false" + ev.primitive = s"(!(${eval.isNull}))" + eval.code + } } /** @@ -95,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } numNonNulls >= n } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val nonnull = ctx.freshName("nonnull") + val code = children.map { e => + val eval = e.gen(ctx) + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + }.mkString("\n") + s""" + int $nonnull = 0; + $code + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = $nonnull >= $n; + """ + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 58273b166f..1d0f19a400 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -82,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex case b: Boolean => !b } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"!($c)") + } } /** @@ -141,6 +146,29 @@ case class And(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `false`, if any of them is `false` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = false; + + if (!${eval1.isNull} && !${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && !${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = true; + } else { + ${ev.isNull} = true; + } + } + """ + } } case class Or(left: Expression, right: Expression) @@ -167,6 +195,29 @@ case class Or(left: Expression, right: Expression) } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + + // The result should be `true`, if any of them is `true` whenever the other is null or not. + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.primitive} = true; + + if (!${eval1.isNull} && ${eval1.primitive}) { + } else { + ${eval2.code} + if (!${eval2.isNull} && ${eval2.primitive}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.primitive} = false; + } else { + ${ev.isNull} = true; + } + } + """ + } } abstract class BinaryComparison extends BinaryExpression with Predicate { @@ -198,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + left.dataType match { + case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) + case TimestampType => + // java.sql.Timestamp does not have compare() + super.genCode(ctx, ev) + case other => defineCodeGen (ctx, ev, { + (c1, c2) => s"$c1.compare($c2) $symbol 0" + }) + } + } + protected def evalInternal(evalE1: Any, evalE2: Any): Any = sys.error(s"BinaryComparisons must override either eval or evalInternal") } @@ -215,6 +280,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -235,6 +303,17 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp l == r } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) + ev.isNull = "false" + eval1.code + eval2.code + s""" + boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) || + (!${eval1.isNull} && $equalCode); + """ + } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { @@ -309,6 +388,27 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val condEval = predicate.gen(ctx) + val trueEval = trueValue.gen(ctx) + val falseEval = falseValue.gen(ctx) + + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.primitive}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.primitive} = ${trueEval.primitive}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.primitive} = ${falseEval.primitive}; + } + """ + } + override def toString: String = s"if ($predicate) $trueValue else $falseValue" } @@ -393,6 +493,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (!${cond.isNull} && ${cond.primitive}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + $cases + $other + """ + } + override def toString: String = { "CASE" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" @@ -444,6 +586,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW return res } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val keyEval = key.gen(ctx) + val len = branchesArr.length + val got = ctx.freshName("got") + + val cases = (0 until len/2).map { i => + val cond = branchesArr(i * 2).gen(ctx) + val res = branchesArr(i * 2 + 1).gen(ctx) + s""" + if (!$got) { + ${cond.code} + if (${keyEval.isNull} && ${cond.isNull} || + !${keyEval.isNull} && !${cond.isNull} + && ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) { + $got = true; + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + } + """ + }.mkString("\n") + + val other = if (len % 2 == 1) { + val res = branchesArr(len - 1).gen(ctx) + s""" + if (!$got) { + ${res.code} + ${ev.isNull} = ${res.isNull}; + ${ev.primitive} = ${res.primitive}; + } + """ + } else { + "" + } + + s""" + boolean $got = false; + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${keyEval.code} + $cases + $other + """ + } + private def equalNullSafe(l: Any, r: Any) = { if (l == null && r == null) { true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b65bf165f2..b39349b988 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -60,6 +61,17 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + elementType match { + case IntegerType | LongType => + ev.isNull = "false" + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}(); + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"new Set($dataType)" } @@ -91,6 +103,25 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val itemEval = item.gen(ctx) + val setEval = set.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = "false" + ev.primitive = setEval.primitive + itemEval.code + setEval.code + s""" + if (!${itemEval.isNull} && !${setEval.isNull}) { + (($htype)${setEval.primitive}).add(${itemEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } + override def toString: String = s"$set += $item" } @@ -116,14 +147,31 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres val rightValue = iterator.next() leftEval.add(rightValue) } - leftEval - } else { - null } + leftEval } else { null } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType + elementType match { + case IntegerType | LongType => + val leftEval = left.gen(ctx) + val rightEval = right.gen(ctx) + val htype = ctx.javaType(dataType) + + ev.isNull = leftEval.isNull + ev.primitive = leftEval.primitive + leftEval.code + rightEval.code + s""" + if (!${leftEval.isNull} && !${rightEval.isNull}) { + ${leftEval.primitive}.union((${htype})${rightEval.primitive}); + } + """ + case _ => super.genCode(ctx, ev) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index c4ef9c3090..78adb509b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { @@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") + } } /** @@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") + } } /** A base trait for functions that compare two strings, returning a boolean. */ @@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + } } /** @@ -189,6 +201,9 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + } } /** @@ -197,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5df528770c..eea2edc323 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -35,11 +36,20 @@ import org.apache.spark.sql.types._ class ExpressionEvaluationBaseSuite extends SparkFunSuite { + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow) + checkEvaluationWithGeneratedProjection(expression, expected, inputRow) + } + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) } - def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + def checkEvaluationWithoutCodegen( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite { } } + def checkEvaluationWithGeneratedMutableProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + + val plan = try { + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val ctx = GenerateProjection.newCodeGenContext() + val evaluated = expression.gen(ctx) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + def checkEvaluationWithGeneratedProjection( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val ctx = GenerateProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |Expressions: ${expression} + |Code: ${evaluated} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + def checkDoubleEvaluation( expression: Expression, expected: Spread[Double], @@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) + checkEvaluation(Literal(false), false) checkEvaluation(Literal(0L), 0L) + List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { + d => { + checkEvaluation(Literal(d), d) + checkEvaluation(Literal(d.toFloat), d.toFloat) + } + } checkEvaluation(Literal("test"), "test") + checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal(1) + Literal(1), 2) } @@ -1367,6 +1447,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // TODO: Make the tests work with codegen. class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite { + override def checkEvaluation( + expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + checkEvaluationWithoutCodegen(expression, expected, inputRow) + } + test("CreateStruct") { val row = Row(1, 2, 3) val c1 = 'a.int.at(0).as("a") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index 8cfd853afa..371a73181d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ /** - * Overrides our expression evaluation tests to use code generation for evaluation. + * Additional tests for code generation. */ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val plan = try { - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() - } catch { - case e: Throwable => - val ctx = GenerateProjection.newCodeGenContext() - val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow).apply(0) - if (actual != expected) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } - test("multithreaded eval") { import scala.concurrent._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala deleted file mode 100644 index 9ab1f7d7ad..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen._ - -/** - * Overrides our expression evaluation tests to use generated code on mutable rows. - */ -class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { - override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { - val ctx = GenerateProjection.newCodeGenContext() - lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx) - - val plan = try { - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |${evaluated.code} - |$e - """.stripMargin) - } - - val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) - if (actual.hashCode() != expectedRow.hashCode()) { - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |${evaluated.code} - """.stripMargin) - } - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 8979a0a210..d9a010a981 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { check("10", Literal.create(10, IntegerType)) check("1000000000000000", Literal.create(1000000000000000L, LongType)) - check("1.5", Literal.create(1.5, FloatType)) + check("1.5", Literal.create(1.5f, FloatType)) check("hello", Literal.create("hello", StringType)) check(defaultPartitionName, Literal.create(null, NullType)) } @@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5, FloatType))) + Literal.create(1.5f, FloatType))) }) check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5, FloatType))) + ArrayBuffer(Literal.create(1.5f, FloatType))) }) check("file:///", None)