[SPARK-8117] [SQL] Push codegen implementation into each Expression

This PR move codegen implementation of expressions into Expression class itself, make it easy to manage.

It introduces two APIs in Expression:
```
def gen(ctx: CodeGenContext): GeneratedExpressionCode
def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code
```

gen(ctx) will call genSource(ctx, ev) to generate Java source code for the current expression. A expression needs to override genSource().

Here are the types:
```
type Term String
type Code String

/**
 * Java source for evaluating an [[Expression]] given a [[Row]] of input.
 */
case class GeneratedExpressionCode(var code: Code,
                               nullTerm: Term,
                               primitiveTerm: Term,
                               objectTerm: Term)
/**
 * A context for codegen, which is used to bookkeeping the expressions those are not supported
 * by codegen, then they are evaluated directly. The unsupported expression is appended at the
 * end of `references`, the position of it is kept in the code, used to access and evaluate it.
 */
class CodeGenContext {
  /**
   * Holding all the expressions those do not support codegen, will be evaluated directly.
   */
  val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]()
}
```

This is basically #6660, but fixed style violation and compilation failure.

Author: Davies Liu <davies@databricks.com>
Author: Reynold Xin <rxin@databricks.com>

Closes #6690 from rxin/codegen and squashes the following commits:

e1368c2 [Reynold Xin] Fixed tests.
73db80e [Reynold Xin] Fixed compilation failure.
19d6435 [Reynold Xin] Fixed style violation.
9adaeaf [Davies Liu] address comments
f42c732 [Davies Liu] improve coverage and tests
bad6828 [Davies Liu] address comments
e03edaa [Davies Liu] consts fold
86fac2c [Davies Liu] fix style
02262c9 [Davies Liu] address comments
b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen
48c454f [Reynold Xin] Some code gen update.
2344bc0 [Davies Liu] fix test
12ff88a [Davies Liu] fix build
c5fb514 [Davies Liu] rename
8c6d82d [Davies Liu] update docs
b145047 [Davies Liu] fix style
e57959d [Davies Liu] add type alias
3ff25f8 [Davies Liu] refactor
593d617 [Davies Liu] pushing codegen into Expression
This commit is contained in:
Davies Liu 2015-06-07 14:11:20 -07:00 committed by Reynold Xin
parent b127ff8a0c
commit 5e7b6b67be
23 changed files with 1036 additions and 718 deletions

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.trees
@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
"""
}
}
object BindReferences extends Logging {

View file

@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
val evaluated = child.eval(input)
if (evaluated == null) null else cast(evaluated)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// TODO(cg): Add support for more data types.
(child.dataType, dataType) match {
case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)
case (BooleanType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
case (dt: DecimalType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c.isZero()")
case (dt: NumericType, BooleanType) =>
defineCodeGen(ctx, ev, c => s"$c != 0")
case (_: DecimalType, IntegerType) =>
defineCodeGen(ctx, ev, c => s"($c).toInt()")
case (_: DecimalType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
case (_: NumericType, dt: NumericType) =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
case other =>
super.genCode(ctx, ev)
}
}
}
object Cast {

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] {
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: Row = null): Any
/**
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
* can be used to generate the result of evaluating the expression on an input row.
*
* @param ctx a [[CodeGenContext]]
* @return [[GeneratedExpressionCode]]
*/
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
val isNull = ctx.freshName("isNull")
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
}
/**
* Returns Java source code that can be compiled to evaluate this expression.
* The default behavior is to call the eval method of the expression. Concrete expression
* implementations should override this to do actual code generation.
*
* @param ctx a [[CodeGenContext]]
* @param ev an [[GeneratedExpressionCode]] with unique terms.
* @return Java source code
*/
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
/* expression: ${this} */
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
boolean ${ev.isNull} = ${objectTerm} == null;
${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
}
"""
}
/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and input data types checking passed, and `false` if it still contains any unresolved
@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def nullable: Boolean = left.nullable || right.nullable
override def toString: String = s"($left $symbol $right)"
/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts two variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: (Term, Term) => Code): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (left.dataType != right.dataType) {
// log.warn(s"${left.dataType} != ${right.dataType}")
}
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val resultCode = f(eval1.primitive, eval2.primitive)
s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${eval2.code}
if(!${eval2.isNull}) {
${ev.primitive} = $resultCode;
} else {
${ev.isNull} = true;
}
}
"""
}
}
private[sql] object BinaryExpression {
@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
* null, and if not not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
* defineCodeGen(ctx, ev, c => s"!($c)")
* }}}
*
* @param f function that accepts a variable name and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodeGenContext,
ev: GeneratedExpressionCode,
f: Term => Code): Code = {
val eval = child.gen(ctx)
// reuse the previous isNull
ev.isNull = eval.isNull
eval.code + s"""
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = ${f(eval.primitive)};
}
"""
}
}
// TODO Semantically we probably not need GroupExpression

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -49,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
private lazy val numeric = TypeUtils.getNumeric(dataType)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
}
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
}
@ -67,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
if (value < 0) null
else math.sqrt(value)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
if (${eval.primitive} < 0.0) {
${ev.isNull} = true;
} else {
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
}
}
"""
}
}
/**
@ -86,6 +107,9 @@ case class Abs(child: Expression) extends UnaryArithmetic {
abstract class BinaryArithmetic extends BinaryExpression {
self: Product =>
/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String = ""
override def dataType: DataType = left.dataType
override def checkInputDataTypes(): TypeCheckResult = {
@ -114,6 +138,17 @@ abstract class BinaryArithmetic extends BinaryExpression {
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev, (eval1, eval2) =>
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}
@ -124,6 +159,7 @@ private[sql] object BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"
override def decimalMethod: String = "$plus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@ -138,6 +174,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "-"
override def decimalMethod: String = "$minus"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@ -152,6 +189,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "*"
override def decimalMethod: String = "$times"
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
@ -166,6 +204,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "/"
override def decimalMethod: String = "$divide"
override def nullable: Boolean = true
override lazy val resolved =
@ -192,10 +232,40 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}
}
/**
* Special case handling due to division by 0 => null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
} else {
s"$symbol"
}
eval1.code + eval2.code +
s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "%"
override def decimalMethod: String = "reminder"
override def nullable: Boolean = true
override lazy val resolved =
@ -222,6 +292,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
}
}
}
/**
* Special case handling for x % 0 ==> null.
*/
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val test = if (left.dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitive} == 0"
}
val method = if (left.dataType.isInstanceOf[DecimalType]) {
s".$decimalMethod"
} else {
s"$symbol"
}
eval1.code + eval2.code +
s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
if (${eval1.isNull} || ${eval2.isNull} || $test) {
${ev.isNull} = true;
} else {
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
}
"""
}
}
/**
@ -271,7 +369,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
}
/**
* A function that calculates bitwise xor(^) of two numbers.
* A function that calculates bitwise xor of two numbers.
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"
@ -313,6 +411,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
}
protected override def evalInternal(evalE: Any) = not(evalE)
}
@ -340,6 +442,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitive} > ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitive} = ${eval2.primitive};
}
}
"""
} else {
super.genCode(ctx, ev)
}
}
override def toString: String = s"MaxOf($left, $right)"
}
@ -367,5 +496,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
if (ctx.isNativeType(left.dataType)) {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
eval1.code + eval2.code + s"""
boolean ${ev.isNull} = false;
${ctx.javaType(left.dataType)} ${ev.primitive} =
${ctx.defaultValue(left.dataType)};
if (${eval1.isNull}) {
${ev.isNull} = ${eval2.isNull};
${ev.primitive} = ${eval2.primitive};
} else if (${eval2.isNull}) {
${ev.isNull} = ${eval1.isNull};
${ev.primitive} = ${eval1.primitive};
} else {
if (${eval1.primitive} < ${eval2.primitive}) {
${ev.primitive} = ${eval1.primitive};
} else {
${ev.primitive} = ${eval2.primitive};
}
}
"""
} else {
super.genCode(ctx, ev)
}
}
override def toString: String = s"MinOf($left, $right)"
}

View file

@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.codehaus.janino.ClassBodyEvaluator
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@ -32,6 +31,157 @@ import org.apache.spark.sql.types._
class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
/**
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
*/
class CodeGenContext {
/**
* Holding all the expressions those do not support codegen, will be evaluated directly.
*/
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
val stringType: String = classOf[UTF8String].getName
val decimalType: String = classOf[Decimal].getName
private val curId = new java.util.concurrent.atomic.AtomicInteger()
/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
def freshName(prefix: String): Term = {
s"$prefix${curId.getAndIncrement}"
}
/**
* Return the code to access a column for given DataType
*/
def getColumn(dataType: DataType, ordinal: Int): Code = {
if (isNativeType(dataType)) {
s"i.${accessorForType(dataType)}($ordinal)"
} else {
s"(${boxedType(dataType)})i.apply($ordinal)"
}
}
/**
* Return the code to update a column in Row for given DataType
*/
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
if (isNativeType(dataType)) {
s"${mutatorForType(dataType)}($ordinal, $value)"
} else {
s"update($ordinal, $value)"
}
}
/**
* Return the name of accessor in Row for a DataType
*/
def accessorForType(dt: DataType): Term = dt match {
case IntegerType => "getInt"
case other => s"get${boxedType(dt)}"
}
/**
* Return the name of mutator in Row for a DataType
*/
def mutatorForType(dt: DataType): Term = dt match {
case IntegerType => "setInt"
case other => s"set${boxedType(dt)}"
}
/**
* Return the Java type for a DataType
*/
def javaType(dt: DataType): Term = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
case ByteType => "byte"
case DoubleType => "double"
case FloatType => "float"
case BooleanType => "boolean"
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case DateType => "int"
case TimestampType => "java.sql.Timestamp"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
}
/**
* Return the boxed type in Java
*/
def boxedType(dt: DataType): Term = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case DateType => "Integer"
case _ => javaType(dt)
}
/**
* Return the representation of default value for given DataType
*/
def defaultValue(dt: DataType): Term = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "(short)-1"
case LongType => "-1L"
case ByteType => "(byte)-1"
case DoubleType => "-1.0"
case IntegerType => "-1"
case DateType => "-1"
case _ => "null"
}
/**
* Returns a function to generate equal expression in Java
*/
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
case BinaryType => { case (eval1, eval2) =>
s"java.util.Arrays.equals($eval1, $eval2)" }
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
{ case (eval1, eval2) => s"$eval1 == $eval2" }
case other =>
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
}
/**
* List of data types that have special accessors and setters in [[Row]].
*/
val nativeTypes =
Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
/**
* Returns true if the data type has a special accessor and setter in [[Row]].
*/
def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
}
/**
* A base class for generators of byte code to perform expression evaluation. Includes a set of
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
@ -39,14 +189,9 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
*/
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
protected val rowType = classOf[Row].getName
protected val stringType = classOf[UTF8String].getName
protected val decimalType = classOf[Decimal].getName
protected val exprType = classOf[Expression].getName
protected val mutableRowType = classOf[MutableRow].getName
protected val genericMutableRowType = classOf[GenericMutableRow].getName
private val curId = new java.util.concurrent.atomic.AtomicInteger()
protected val exprType: String = classOf[Expression].getName
protected val mutableRowType: String = classOf[MutableRow].getName
protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
/**
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
@ -75,10 +220,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
*/
protected def compile(code: String): Class[_] = {
val startTime = System.nanoTime()
val clazz = new ClassBodyEvaluator(code).getClazz()
val clazz = try {
new ClassBodyEvaluator(code).getClazz()
} catch {
case e: Exception =>
logError(s"failed to compile:\n $code", e)
throw e
}
val endTime = System.nanoTime()
def timeMs: Double = (endTime - startTime).toDouble / 1000000
logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms")
logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")
clazz
}
@ -112,586 +263,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
/** Generates the requested evaluator given already bound expression(s). */
def generate(expressions: InType): OutType = cache.get(canonicalize(expressions))
/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
*
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
* function.)
*/
protected def freshName(prefix: String): String = {
s"$prefix${curId.getAndIncrement}"
}
/**
* Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
* @param nullTerm A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
* valid if `nullTerm` is set to `true`.
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
*/
protected case class EvaluatedExpression(
code: String,
nullTerm: String,
primitiveTerm: String,
objectTerm: String)
/**
* A context for codegen, which is used to bookkeeping the expressions those are not supported
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
*/
protected class CodeGenContext {
/**
* Holding all the expressions those do not support codegen, will be evaluated directly.
*/
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
}
/**
* Create a new codegen context for expression evaluator, used to store those
* expressions that don't support codegen
*/
def newCodeGenContext(): CodeGenContext = {
new CodeGenContext()
}
/**
* Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
* can be used to determine the result of evaluating the expression on an input row.
*/
def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = {
val primitiveTerm = freshName("primitiveTerm")
val nullTerm = freshName("nullTerm")
val objectTerm = freshName("objectTerm")
implicit class Evaluate1(e: Expression) {
def castOrNull(f: String => String, dataType: DataType): String = {
val eval = expressionEvaluator(e, ctx)
eval.code +
s"""
boolean $nullTerm = ${eval.nullTerm};
${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
if (!$nullTerm) {
$primitiveTerm = ${f(eval.primitiveTerm)};
}
"""
new CodeGenContext
}
}
implicit class Evaluate2(expressions: (Expression, Expression)) {
/**
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
* the same type. If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function from two primitive term names to a tree that evaluates them.
*/
def evaluate(f: (String, String) => String): String =
evaluateAs(expressions._1.dataType)(f)
def evaluateAs(resultType: DataType)(f: (String, String) => String): String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (expressions._1.dataType != expressions._2.dataType) {
log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
}
val eval1 = expressionEvaluator(expressions._1, ctx)
val eval2 = expressionEvaluator(expressions._2, ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
eval1.code + eval2.code +
s"""
boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)};
if(!$nullTerm) {
$primitiveTerm = (${primitiveForType(resultType)})($resultCode);
}
"""
}
}
val inputTuple = "i"
// TODO: Skip generation of null handling code when expression are not nullable.
val primitiveEvaluation: PartialFunction[Expression, String] = {
case b @ BoundReference(ordinal, dataType, nullable) =>
s"""
final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)});
"""
case expressions.Literal(null, dataType) =>
s"""
final boolean $nullTerm = true;
${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
"""
case expressions.Literal(value: UTF8String, StringType) =>
val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
s"""
final boolean $nullTerm = false;
${stringType} $primitiveTerm =
new ${stringType}().set(${arr});
"""
case expressions.Literal(value, FloatType) =>
s"""
final boolean $nullTerm = false;
float $primitiveTerm = ${value}f;
"""
case expressions.Literal(value, dt @ DecimalType()) =>
s"""
final boolean $nullTerm = false;
${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value);
"""
case expressions.Literal(value, dataType) =>
s"""
final boolean $nullTerm = false;
${primitiveForType(dataType)} $primitiveTerm = $value;
"""
case Cast(child @ BinaryType(), StringType) =>
child.castOrNull(c =>
s"new ${stringType}().set($c)",
StringType)
case Cast(child @ DateType(), StringType) =>
child.castOrNull(c =>
s"""new ${stringType}().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
case Cast(child @ DecimalType(), IntegerType) =>
child.castOrNull(c => s"($c).toInt()", IntegerType)
case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
e.castOrNull(c =>
s"new ${stringType}().set(String.valueOf($c))",
StringType)
case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
(e1, e2).evaluateAs (BooleanType) {
case (eval1, eval2) =>
s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
}
case EqualTo(e1, e2) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" }
case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" }
case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" }
case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" }
case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" }
case And(e1, e2) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
s"""
${eval1.code}
boolean $nullTerm = false;
boolean $primitiveTerm = false;
if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) {
} else {
${eval2.code}
if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) {
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
$primitiveTerm = true;
} else {
$nullTerm = true;
}
}
"""
case Or(e1, e2) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
s"""
${eval1.code}
boolean $nullTerm = false;
boolean $primitiveTerm = false;
if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
$primitiveTerm = true;
} else {
${eval2.code}
if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
$primitiveTerm = true;
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
$primitiveTerm = false;
} else {
$nullTerm = true;
}
}
"""
case Not(child) =>
// Uh, bad function name...
child.castOrNull(c => s"!$c", BooleanType)
case Add(e1 @ DecimalType(), e2 @ DecimalType()) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" }
case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" }
case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" }
case Divide(e1 @ DecimalType(), e2 @ DecimalType()) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = null;
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
$nullTerm = true;
} else {
$primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm});
}
"""
case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
$nullTerm = true;
} else {
$primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm});
}
"""
case Add(e1, e2) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" }
case Subtract(e1, e2) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" }
case Multiply(e1, e2) =>
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" }
case Divide(e1, e2) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
$nullTerm = true;
} else {
$primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm};
}
"""
case Remainder(e1, e2) =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
$nullTerm = true;
} else {
$primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm};
}
"""
case IsNotNull(e) =>
val eval = expressionEvaluator(e, ctx)
s"""
${eval.code}
boolean $nullTerm = false;
boolean $primitiveTerm = !${eval.nullTerm};
"""
case IsNull(e) =>
val eval = expressionEvaluator(e, ctx)
s"""
${eval.code}
boolean $nullTerm = false;
boolean $primitiveTerm = ${eval.nullTerm};
"""
case e @ Coalesce(children) =>
s"""
boolean $nullTerm = true;
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
""" +
children.map { c =>
val eval = expressionEvaluator(c, ctx)
s"""
if($nullTerm) {
${eval.code}
if(!${eval.nullTerm}) {
$nullTerm = false;
$primitiveTerm = ${eval.primitiveTerm};
}
}
"""
}.mkString("\n")
case e @ expressions.If(condition, trueValue, falseValue) =>
val condEval = expressionEvaluator(condition, ctx)
val trueEval = expressionEvaluator(trueValue, ctx)
val falseEval = expressionEvaluator(falseValue, ctx)
s"""
boolean $nullTerm = false;
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
${condEval.code}
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
${trueEval.code}
$nullTerm = ${trueEval.nullTerm};
$primitiveTerm = ${trueEval.primitiveTerm};
} else {
${falseEval.code}
$nullTerm = ${falseEval.nullTerm};
$primitiveTerm = ${falseEval.primitiveTerm};
}
"""
case NewSet(elementType) =>
s"""
boolean $nullTerm = false;
${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}();
"""
case AddItemToSet(item, set) =>
val itemEval = expressionEvaluator(item, ctx)
val setEval = expressionEvaluator(set, ctx)
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
val htype = hashSetForType(elementType)
itemEval.code + setEval.code +
s"""
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
}
boolean $nullTerm = false;
${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm};
"""
case CombineSets(left, right) =>
val leftEval = expressionEvaluator(left, ctx)
val rightEval = expressionEvaluator(right, ctx)
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
val htype = hashSetForType(elementType)
leftEval.code + rightEval.code +
s"""
boolean $nullTerm = false;
${htype} $primitiveTerm =
(${htype})${leftEval.primitiveTerm};
$primitiveTerm.union((${htype})${rightEval.primitiveTerm});
"""
case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
$nullTerm = ${eval2.nullTerm};
$primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
$nullTerm = ${eval1.nullTerm};
$primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
$primitiveTerm = ${eval1.primitiveTerm};
} else {
$primitiveTerm = ${eval2.primitiveTerm};
}
}
"""
case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
val eval1 = expressionEvaluator(e1, ctx)
val eval2 = expressionEvaluator(e2, ctx)
eval1.code + eval2.code +
s"""
boolean $nullTerm = false;
${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
if (${eval1.nullTerm}) {
$nullTerm = ${eval2.nullTerm};
$primitiveTerm = ${eval2.primitiveTerm};
} else if (${eval2.nullTerm}) {
$nullTerm = ${eval1.nullTerm};
$primitiveTerm = ${eval1.primitiveTerm};
} else {
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
$primitiveTerm = ${eval1.primitiveTerm};
} else {
$primitiveTerm = ${eval2.primitiveTerm};
}
}
"""
case UnscaledValue(child) =>
val childEval = expressionEvaluator(child, ctx)
childEval.code +
s"""
boolean $nullTerm = ${childEval.nullTerm};
long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong();
"""
case MakeDecimal(child, precision, scale) =>
val eval = expressionEvaluator(child, ctx)
eval.code +
s"""
boolean $nullTerm = ${eval.nullTerm};
org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())};
if (!$nullTerm) {
$primitiveTerm = new org.apache.spark.sql.types.Decimal();
$primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale);
$nullTerm = $primitiveTerm == null;
}
"""
}
// If there was no match in the partial function above, we fall back on calling the interpreted
// expression evaluator.
val code: String =
primitiveEvaluation.lift.apply(e).getOrElse {
logError(s"No rules to generate $e")
ctx.references += e
s"""
/* expression: ${e} */
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
boolean $nullTerm = $objectTerm == null;
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm;
"""
}
EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
}
protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = {
dataType match {
case StringType => s"(${stringType})$inputRow.apply($ordinal)"
case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)"
}
}
protected def setColumn(
destinationRow: String,
dataType: DataType,
ordinal: Int,
value: String): String = {
dataType match {
case StringType => s"$destinationRow.update($ordinal, $value)"
case dt: DataType if isNativeType(dt) =>
s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => s"$destinationRow.update($ordinal, $value)"
}
}
protected def accessorForType(dt: DataType) = dt match {
case IntegerType => "getInt"
case other => s"get${termForType(dt)}"
}
protected def mutatorForType(dt: DataType) = dt match {
case IntegerType => "setInt"
case other => s"set${termForType(dt)}"
}
protected def hashSetForType(dt: DataType): String = dt match {
case IntegerType => classOf[IntegerHashSet].getName
case LongType => classOf[LongHashSet].getName
case unsupportedType =>
sys.error(s"Code generation not support for hashset of type $unsupportedType")
}
protected def primitiveForType(dt: DataType): String = dt match {
case IntegerType => "int"
case LongType => "long"
case ShortType => "short"
case ByteType => "byte"
case DoubleType => "double"
case FloatType => "float"
case BooleanType => "boolean"
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case DateType => "int"
case TimestampType => "java.sql.Timestamp"
case _ => "Object"
}
protected def defaultPrimitive(dt: DataType): String = dt match {
case BooleanType => "false"
case FloatType => "-1.0f"
case ShortType => "-1"
case LongType => "-1"
case ByteType => "-1"
case DoubleType => "-1.0"
case IntegerType => "-1"
case DateType => "-1"
case dt: DecimalType => "null"
case StringType => "null"
case _ => "null"
}
protected def termForType(dt: DataType): String = dt match {
case IntegerType => "Integer"
case LongType => "Long"
case ShortType => "Short"
case ByteType => "Byte"
case DoubleType => "Double"
case FloatType => "Float"
case BooleanType => "Boolean"
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case DateType => "Integer"
case TimestampType => "java.sql.Timestamp"
case _ => "Object"
}
/**
* List of data types that have special accessors and setters in [[Row]].
*/
protected val nativeTypes =
Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
/**
* Returns true if the data type has a special accessor and setter in [[Row]].
*/
protected def isNativeType(dt: DataType) = nativeTypes.contains(dt)
}

View file

@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
val evaluationCode = expressionEvaluator(e, ctx)
val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
if(${evaluationCode.nullTerm})
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
"""
}.mkString("\n")
val code = s"""

View file

@ -52,15 +52,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
val ctx = newCodeGenContext()
val comparisons = ordering.zipWithIndex.map { case (order, i) =>
val evalA = expressionEvaluator(order.child, ctx)
val evalB = expressionEvaluator(order.child, ctx)
val evalA = order.child.gen(ctx)
val evalB = order.child.gen(ctx)
val asc = order.direction == Ascending
val compare = order.child.dataType match {
case BinaryType =>
s"""
{
byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
int j = 0;
while (j < x.length && j < y.length) {
if (x[j] != y[j]) return x[j] - y[j];
@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _: NumericType =>
s"""
if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
if (${evalA.primitive} != ${evalB.primitive}) {
if (${evalA.primitive} > ${evalB.primitive}) {
return ${if (asc) "1" else "-1"};
} else {
return ${if (asc) "-1" else "1"};
@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
}"""
case _ =>
s"""
int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
int comp = ${evalA.primitive}.compare(${evalB.primitive});
if (comp != 0) {
return ${if (asc) "comp" else "-comp"};
}"""
@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
${evalA.code}
i = $b;
${evalB.code}
if (${evalA.nullTerm} && ${evalB.nullTerm}) {
if (${evalA.isNull} && ${evalB.isNull}) {
// Nothing
} else if (${evalA.nullTerm}) {
} else if (${evalA.isNull}) {
return ${if (order.direction == Ascending) "-1" else "1"};
} else if (${evalB.nullTerm}) {
} else if (${evalB.isNull}) {
return ${if (order.direction == Ascending) "1" else "-1"};
} else {
$compare

View file

@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
protected def create(predicate: Expression): ((Row) => Boolean) = {
val ctx = newCodeGenContext()
val eval = expressionEvaluator(predicate, ctx)
val eval = predicate.gen(ctx)
val code = s"""
import org.apache.spark.sql.Row;
@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
@Override
public boolean eval(Row i) {
${eval.code}
return !${eval.nullTerm} && ${eval.primitiveTerm};
return !${eval.isNull} && ${eval.primitive};
}
}"""

View file

@ -45,19 +45,19 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
val ctx = newCodeGenContext()
val columns = expressions.zipWithIndex.map {
case (e, i) =>
s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n"
s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
}.mkString("\n ")
val initColumns = expressions.zipWithIndex.map {
case (e, i) =>
val eval = expressionEvaluator(e, ctx)
val eval = e.gen(ctx)
s"""
{
// column$i
${eval.code}
nullBits[$i] = ${eval.nullTerm};
if(!${eval.nullTerm}) {
c$i = ${eval.primitiveTerm};
nullBits[$i] = ${eval.isNull};
if (!${eval.isNull}) {
c$i = ${eval.primitive};
}
}
"""
@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}.mkString("\n ")
val updateCases = expressions.zipWithIndex.map { case (e, i) =>
s"case $i: { c$i = (${termForType(e.dataType)})value; return;}"
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
}.mkString("\n ")
val specificAccessorFunctions = nativeTypes.map { dataType =>
val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType =>
s"case $i: return c$i;"
@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) {
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
if (isNullAt(i)) {
return ${defaultPrimitive(dataType)};
return ${ctx.defaultValue(dataType)};
}
switch (i) {
$cases
}
return ${defaultPrimitive(dataType)};
return ${ctx.defaultValue(dataType)};
}"""
} else {
""
}
}.mkString("\n")
val specificMutatorFunctions = nativeTypes.map { dataType =>
val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
val cases = expressions.zipWithIndex.map {
case (e, i) if e.dataType == dataType =>
s"case $i: { c$i = value; return; }"
@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
if (cases.count(_ != '\n') > 0) {
s"""
@Override
public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) {
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
nullBits[i] = false;
switch (i) {
$cases
@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case LongType => s"$col ^ ($col >>> 32)"
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"

View file

@ -27,6 +27,9 @@ import org.apache.spark.util.Utils
*/
package object codegen {
type Term = String
type Code = String
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
val batches =

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
@ -35,6 +36,10 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
childResult.asInstanceOf[Decimal].toUnscaledLong
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
}
}
/** Create a Decimal from an unscaled Long value */
@ -53,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.decimalType} ${ev.primitive} = null;
if (!${ev.isNull}) {
${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
${eval.primitive}, $precision, $scale);
${ev.isNull} = ${ev.primitive} == null;
}
"""
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
@ -78,7 +79,60 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
override def toString: String = if (value != null) value.toString else "null"
override def equals(other: Any): Boolean = other match {
case o: Literal =>
dataType.equals(o.dataType) &&
(value == null && null == o.value || value != null && value.equals(o.value))
case _ => false
}
override def eval(input: Row): Any = value
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
// change the isNull and primitive to consts, to inline them
if (value == null) {
ev.isNull = "true"
ev.primitive = ctx.defaultValue(dataType)
""
} else {
dataType match {
case BooleanType =>
ev.isNull = "false"
ev.primitive = value.toString
""
case FloatType => // This must go before NumericType
val v = value.asInstanceOf[Float]
if (v.isNaN || v.isInfinite) {
super.genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.primitive = s"${value}f"
""
}
case DoubleType => // This must go before NumericType
val v = value.asInstanceOf[Double]
if (v.isNaN || v.isInfinite) {
super.genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.primitive = s"${value}"
""
}
case ByteType | ShortType => // This must go before NumericType
ev.isNull = "false"
ev.primitive = s"(${ctx.javaType(dataType)})$value"
""
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
ev.isNull = "false"
ev.primitive = value.toString
""
// eval() version may be faster for non-primitive types
case other =>
super.genCode(ctx, ev)
}
}
}
}
// TODO: Specialize

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.mathfuncs
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row}
import org.apache.spark.sql.types._
@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
}
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)")
}
}
case class Atan2(left: Expression, right: Expression)
@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression)
}
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
}
}
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
"""
}
}

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.mathfuncs
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression}
import org.apache.spark.sql.types._
@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
if (result.isNaN) null else result
}
}
// name of function in java.lang.Math
def funcName: String = name.toLowerCase
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
eval.code + s"""
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
if (Double.valueOf(${ev.primitive}).isNaN()) {
${ev.isNull} = true;
}
}
"""
}
}
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND")
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
}
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES")
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
override def funcName: String = "toDegrees"
}
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS")
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
override def funcName: String = "toRadians"
}

View file

@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.LeafNode
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types._
object NamedExpression {
@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)(
override def eval(input: Row): Any = child.eval(input)
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
override def metadata: Metadata = {

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types.DataType
@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
result
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
""" +
children.map { e =>
val eval = e.gen(ctx)
s"""
if (${ev.isNull}) {
${eval.code}
if (!${eval.isNull}) {
${ev.isNull} = false;
${ev.primitive} = ${eval.primitive};
}
}
"""
}.mkString("\n")
}
}
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
@ -61,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
child.eval(input) == null
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
ev.isNull = "false"
ev.primitive = eval.isNull
eval.code
}
override def toString: String = s"IS NULL $child"
}
@ -72,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
override def eval(input: Row): Any = {
child.eval(input) != null
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval = child.gen(ctx)
ev.isNull = "false"
ev.primitive = s"(!(${eval.isNull}))"
eval.code
}
}
/**
@ -95,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}
numNonNulls >= n
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val nonnull = ctx.freshName("nonnull")
val code = children.map { e =>
val eval = e.gen(ctx)
s"""
if ($nonnull < $n) {
${eval.code}
if (!${eval.isNull}) {
$nonnull += 1;
}
}
"""
}.mkString("\n")
s"""
int $nonnull = 0;
$code
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = $nonnull >= $n;
"""
}
}

View file

@ -18,9 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
object InterpretedPredicate {
def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@ -82,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
case b: Boolean => !b
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"!($c)")
}
}
/**
@ -141,6 +146,29 @@ case class And(left: Expression, right: Expression)
}
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
// The result should be `false`, if any of them is `false` whenever the other is null or not.
s"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = false;
if (!${eval1.isNull} && !${eval1.primitive}) {
} else {
${eval2.code}
if (!${eval2.isNull} && !${eval2.primitive}) {
} else if (!${eval1.isNull} && !${eval2.isNull}) {
${ev.primitive} = true;
} else {
${ev.isNull} = true;
}
}
"""
}
}
case class Or(left: Expression, right: Expression)
@ -167,6 +195,29 @@ case class Or(left: Expression, right: Expression)
}
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
// The result should be `true`, if any of them is `true` whenever the other is null or not.
s"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = true;
if (!${eval1.isNull} && ${eval1.primitive}) {
} else {
${eval2.code}
if (!${eval2.isNull} && ${eval2.primitive}) {
} else if (!${eval1.isNull} && !${eval2.isNull}) {
${ev.primitive} = false;
} else {
${ev.isNull} = true;
}
}
"""
}
}
abstract class BinaryComparison extends BinaryExpression with Predicate {
@ -198,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
left.dataType match {
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
(c1, c3) => s"$c1 $symbol $c3"
})
case TimestampType =>
// java.sql.Timestamp does not have compare()
super.genCode(ctx, ev)
case other => defineCodeGen (ctx, ev, {
(c1, c2) => s"$c1.compare($c2) $symbol 0"
})
}
}
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryComparisons must override either eval or evalInternal")
}
@ -215,6 +280,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
}
}
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
@ -235,6 +303,17 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
l == r
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
ev.isNull = "false"
eval1.code + eval2.code + s"""
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
(!${eval1.isNull} && $equalCode);
"""
}
}
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
@ -309,6 +388,27 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val condEval = predicate.gen(ctx)
val trueEval = trueValue.gen(ctx)
val falseEval = falseValue.gen(ctx)
s"""
${condEval.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${condEval.isNull} && ${condEval.primitive}) {
${trueEval.code}
${ev.isNull} = ${trueEval.isNull};
${ev.primitive} = ${trueEval.primitive};
} else {
${falseEval.code}
${ev.isNull} = ${falseEval.isNull};
${ev.primitive} = ${falseEval.primitive};
}
"""
}
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}
@ -393,6 +493,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
return res
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val len = branchesArr.length
val got = ctx.freshName("got")
val cases = (0 until len/2).map { i =>
val cond = branchesArr(i * 2).gen(ctx)
val res = branchesArr(i * 2 + 1).gen(ctx)
s"""
if (!$got) {
${cond.code}
if (!${cond.isNull} && ${cond.primitive}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
${ev.primitive} = ${res.primitive};
}
}
"""
}.mkString("\n")
val other = if (len % 2 == 1) {
val res = branchesArr(len - 1).gen(ctx)
s"""
if (!$got) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.primitive} = ${res.primitive};
}
"""
} else {
""
}
s"""
boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
$cases
$other
"""
}
override def toString: String = {
"CASE" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
@ -444,6 +586,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
return res
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val keyEval = key.gen(ctx)
val len = branchesArr.length
val got = ctx.freshName("got")
val cases = (0 until len/2).map { i =>
val cond = branchesArr(i * 2).gen(ctx)
val res = branchesArr(i * 2 + 1).gen(ctx)
s"""
if (!$got) {
${cond.code}
if (${keyEval.isNull} && ${cond.isNull} ||
!${keyEval.isNull} && !${cond.isNull}
&& ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
${ev.primitive} = ${res.primitive};
}
}
"""
}.mkString("\n")
val other = if (len % 2 == 1) {
val res = branchesArr(len - 1).gen(ctx)
s"""
if (!$got) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.primitive} = ${res.primitive};
}
"""
} else {
""
}
s"""
boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${keyEval.code}
$cases
$other
"""
}
private def equalNullSafe(l: Any, r: Any) = {
if (l == null && r == null) {
true

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@ -60,6 +61,17 @@ case class NewSet(elementType: DataType) extends LeafExpression {
new OpenHashSet[Any]()
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
elementType match {
case IntegerType | LongType =>
ev.isNull = "false"
s"""
${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}();
"""
case _ => super.genCode(ctx, ev)
}
}
override def toString: String = s"new Set($dataType)"
}
@ -91,6 +103,25 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
elementType match {
case IntegerType | LongType =>
val itemEval = item.gen(ctx)
val setEval = set.gen(ctx)
val htype = ctx.javaType(dataType)
ev.isNull = "false"
ev.primitive = setEval.primitive
itemEval.code + setEval.code + s"""
if (!${itemEval.isNull} && !${setEval.isNull}) {
(($htype)${setEval.primitive}).add(${itemEval.primitive});
}
"""
case _ => super.genCode(ctx, ev)
}
}
override def toString: String = s"$set += $item"
}
@ -116,12 +147,29 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
val rightValue = iterator.next()
leftEval.add(rightValue)
}
}
leftEval
} else {
null
}
} else {
null
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
elementType match {
case IntegerType | LongType =>
val leftEval = left.gen(ctx)
val rightEval = right.gen(ctx)
val htype = ctx.javaType(dataType)
ev.isNull = leftEval.isNull
ev.primitive = leftEval.primitive
leftEval.code + rightEval.code + s"""
if (!${leftEval.isNull} && !${rightEval.isNull}) {
${leftEval.primitive}.union((${htype})${rightEval.primitive});
}
"""
case _ => super.genCode(ctx, ev)
}
}
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
trait StringRegexExpression extends ExpectsInputTypes {
@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: UTF8String): UTF8String = v.toUpperCase()
override def toString: String = s"Upper($child)"
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
}
}
/**
@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
override def convert(v: UTF8String): UTF8String = v.toLowerCase()
override def toString: String = s"Lower($child)"
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
}
}
/** A base trait for functions that compare two strings, returning a boolean. */
@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes {
case class Contains(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
}
}
/**
@ -189,6 +201,9 @@ case class Contains(left: Expression, right: Expression)
case class StartsWith(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
}
}
/**
@ -197,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression)
case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
}
}
/**

View file

@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
@ -35,11 +36,20 @@ import org.apache.spark.sql.types._
class ExpressionEvaluationBaseSuite extends SparkFunSuite {
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
checkEvaluationWithoutCodegen(expression, expected, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
}
def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
expression.eval(inputRow)
}
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
def checkEvaluationWithoutCodegen(
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite {
}
}
def checkEvaluationWithGeneratedMutableProjection(
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val plan = try {
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
val ctx = GenerateProjection.newCodeGenContext()
val evaluated = expression.gen(ctx)
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}
val actual = plan(inputRow).apply(0)
if (actual != expected) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}
def checkEvaluationWithGeneratedProjection(
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val ctx = GenerateProjection.newCodeGenContext()
lazy val evaluated = expression.gen(ctx)
val plan = try {
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
} catch {
case e: Throwable =>
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}
val actual = plan(inputRow)
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
|Mismatched hashCodes for values: $actual, $expectedRow
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
|Expressions: ${expression}
|Code: ${evaluated}
""".stripMargin)
}
if (actual != expectedRow) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}
def checkDoubleEvaluation(
expression: Expression,
expected: Spread[Double],
@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
test("literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)
checkEvaluation(Literal(0L), 0L)
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
d => {
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
}
}
checkEvaluation(Literal("test"), "test")
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal(1) + Literal(1), 2)
}
@ -1367,6 +1447,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
// TODO: Make the tests work with codegen.
class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {
override def checkEvaluation(
expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
checkEvaluationWithoutCodegen(expression, expected, inputRow)
}
test("CreateStruct") {
val row = Row(1, 2, 3)
val c1 = 'a.int.at(0).as("a")

View file

@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
* Overrides our expression evaluation tests to use code generation for evaluation.
* Additional tests for code generation.
*/
class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
override def checkEvaluation(
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val plan = try {
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
val ctx = GenerateProjection.newCodeGenContext()
val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}
val actual = plan(inputRow).apply(0)
if (actual != expected) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}
test("multithreaded eval") {
import scala.concurrent._

View file

@ -1,61 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen._
/**
* Overrides our expression evaluation tests to use generated code on mutable rows.
*/
class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
override def checkEvaluation(
expression: Expression,
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val ctx = GenerateProjection.newCodeGenContext()
lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
val plan = try {
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
} catch {
case e: Throwable =>
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}
val actual = plan(inputRow)
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
|Mismatched hashCodes for values: $actual, $expectedRow
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
|${evaluated.code}
""".stripMargin)
}
if (actual != expectedRow) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
}
}

View file

@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
check("10", Literal.create(10, IntegerType))
check("1000000000000000", Literal.create(1000000000000000L, LongType))
check("1.5", Literal.create(1.5, FloatType))
check("1.5", Literal.create(1.5f, FloatType))
check("hello", Literal.create("hello", StringType))
check(defaultPartitionName, Literal.create(null, NullType))
}
@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
ArrayBuffer(
Literal.create(10, IntegerType),
Literal.create("hello", StringType),
Literal.create(1.5, FloatType)))
Literal.create(1.5f, FloatType)))
})
check("file://path/a=10/b_hello/c=1.5", Some {
PartitionValues(
ArrayBuffer("c"),
ArrayBuffer(Literal.create(1.5, FloatType)))
ArrayBuffer(Literal.create(1.5f, FloatType)))
})
check("file:///", None)