[SPARK-8117] [SQL] Push codegen implementation into each Expression
This PR move codegen implementation of expressions into Expression class itself, make it easy to manage. It introduces two APIs in Expression: ``` def gen(ctx: CodeGenContext): GeneratedExpressionCode def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code ``` gen(ctx) will call genSource(ctx, ev) to generate Java source code for the current expression. A expression needs to override genSource(). Here are the types: ``` type Term String type Code String /** * Java source for evaluating an [[Expression]] given a [[Row]] of input. */ case class GeneratedExpressionCode(var code: Code, nullTerm: Term, primitiveTerm: Term, objectTerm: Term) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported * by codegen, then they are evaluated directly. The unsupported expression is appended at the * end of `references`, the position of it is kept in the code, used to access and evaluate it. */ class CodeGenContext { /** * Holding all the expressions those do not support codegen, will be evaluated directly. */ val references: Seq[Expression] = new mutable.ArrayBuffer[Expression]() } ``` This is basically #6660, but fixed style violation and compilation failure. Author: Davies Liu <davies@databricks.com> Author: Reynold Xin <rxin@databricks.com> Closes #6690 from rxin/codegen and squashes the following commits: e1368c2 [Reynold Xin] Fixed tests. 73db80e [Reynold Xin] Fixed compilation failure. 19d6435 [Reynold Xin] Fixed style violation. 9adaeaf [Davies Liu] address comments f42c732 [Davies Liu] improve coverage and tests bad6828 [Davies Liu] address comments e03edaa [Davies Liu] consts fold 86fac2c [Davies Liu] fix style 02262c9 [Davies Liu] address comments b5d3617 [Davies Liu] Merge pull request #5 from rxin/codegen 48c454f [Reynold Xin] Some code gen update. 2344bc0 [Davies Liu] fix test 12ff88a [Davies Liu] fix build c5fb514 [Davies Liu] rename 8c6d82d [Davies Liu] update docs b145047 [Davies Liu] fix style e57959d [Davies Liu] add type alias 3ff25f8 [Davies Liu] refactor 593d617 [Davies Liu] pushing codegen into Expression
This commit is contained in:
parent
b127ff8a0c
commit
5e7b6b67be
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.catalyst.errors.attachTree
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
|
||||
|
@ -41,6 +42,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
|
|||
override def qualifiers: Seq[String] = throw new UnsupportedOperationException
|
||||
|
||||
override def exprId: ExprId = throw new UnsupportedOperationException
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
s"""
|
||||
boolean ${ev.isNull} = i.isNullAt($ordinal);
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ?
|
||||
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
object BindReferences extends Logging {
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
|
|||
import java.text.{DateFormat, SimpleDateFormat}
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.util.DateUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -433,6 +434,47 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
|
|||
val evaluated = child.eval(input)
|
||||
if (evaluated == null) null else cast(evaluated)
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
// TODO(cg): Add support for more data types.
|
||||
(child.dataType, dataType) match {
|
||||
|
||||
case (BinaryType, StringType) =>
|
||||
defineCodeGen (ctx, ev, c =>
|
||||
s"new ${ctx.stringType}().set($c)")
|
||||
case (DateType, StringType) =>
|
||||
defineCodeGen(ctx, ev, c =>
|
||||
s"""new ${ctx.stringType}().set(
|
||||
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
|
||||
// Special handling required for timestamps in hive test cases since the toString function
|
||||
// does not match the expected output.
|
||||
case (TimestampType, StringType) =>
|
||||
super.genCode(ctx, ev)
|
||||
case (_, StringType) =>
|
||||
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
|
||||
|
||||
// fallback for DecimalType, this must be before other numeric types
|
||||
case (_, dt: DecimalType) =>
|
||||
super.genCode(ctx, ev)
|
||||
|
||||
case (BooleanType, dt: NumericType) =>
|
||||
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
|
||||
case (dt: DecimalType, BooleanType) =>
|
||||
defineCodeGen(ctx, ev, c => s"$c.isZero()")
|
||||
case (dt: NumericType, BooleanType) =>
|
||||
defineCodeGen(ctx, ev, c => s"$c != 0")
|
||||
|
||||
case (_: DecimalType, IntegerType) =>
|
||||
defineCodeGen(ctx, ev, c => s"($c).toInt()")
|
||||
case (_: DecimalType, dt: NumericType) =>
|
||||
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
|
||||
case (_: NumericType, dt: NumericType) =>
|
||||
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
|
||||
|
||||
case other =>
|
||||
super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object Cast {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -51,6 +52,44 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
/** Returns the result of evaluating this expression on a given input Row */
|
||||
def eval(input: Row = null): Any
|
||||
|
||||
/**
|
||||
* Returns an [[GeneratedExpressionCode]], which contains Java source code that
|
||||
* can be used to generate the result of evaluating the expression on an input row.
|
||||
*
|
||||
* @param ctx a [[CodeGenContext]]
|
||||
* @return [[GeneratedExpressionCode]]
|
||||
*/
|
||||
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val primitive = ctx.freshName("primitive")
|
||||
val ve = GeneratedExpressionCode("", isNull, primitive)
|
||||
ve.code = genCode(ctx, ve)
|
||||
ve
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns Java source code that can be compiled to evaluate this expression.
|
||||
* The default behavior is to call the eval method of the expression. Concrete expression
|
||||
* implementations should override this to do actual code generation.
|
||||
*
|
||||
* @param ctx a [[CodeGenContext]]
|
||||
* @param ev an [[GeneratedExpressionCode]] with unique terms.
|
||||
* @return Java source code
|
||||
*/
|
||||
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
ctx.references += this
|
||||
val objectTerm = ctx.freshName("obj")
|
||||
s"""
|
||||
/* expression: ${this} */
|
||||
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
|
||||
boolean ${ev.isNull} = ${objectTerm} == null;
|
||||
${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.primitive} = (${ctx.boxedType(this.dataType)})${objectTerm};
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns `true` if this expression and all its children have been resolved to a specific schema
|
||||
* and input data types checking passed, and `false` if it still contains any unresolved
|
||||
|
@ -116,6 +155,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
|
|||
override def nullable: Boolean = left.nullable || right.nullable
|
||||
|
||||
override def toString: String = s"($left $symbol $right)"
|
||||
|
||||
/**
|
||||
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
|
||||
* the same type. If either of the sub-expressions is null, the result of this computation
|
||||
* is assumed to be null.
|
||||
*
|
||||
* @param f accepts two variable names and returns Java code to compute the output.
|
||||
*/
|
||||
protected def defineCodeGen(
|
||||
ctx: CodeGenContext,
|
||||
ev: GeneratedExpressionCode,
|
||||
f: (Term, Term) => Code): String = {
|
||||
// TODO: Right now some timestamp tests fail if we enforce this...
|
||||
if (left.dataType != right.dataType) {
|
||||
// log.warn(s"${left.dataType} != ${right.dataType}")
|
||||
}
|
||||
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
val resultCode = f(eval1.primitive, eval2.primitive)
|
||||
|
||||
s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${eval2.code}
|
||||
if(!${eval2.isNull}) {
|
||||
${ev.primitive} = $resultCode;
|
||||
} else {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
private[sql] object BinaryExpression {
|
||||
|
@ -128,6 +202,32 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
|
|||
|
||||
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
|
||||
self: Product =>
|
||||
|
||||
/**
|
||||
* Called by unary expressions to generate a code block that returns null if its parent returns
|
||||
* null, and if not not null, use `f` to generate the expression.
|
||||
*
|
||||
* As an example, the following does a boolean inversion (i.e. NOT).
|
||||
* {{{
|
||||
* defineCodeGen(ctx, ev, c => s"!($c)")
|
||||
* }}}
|
||||
*
|
||||
* @param f function that accepts a variable name and returns Java code to compute the output.
|
||||
*/
|
||||
protected def defineCodeGen(
|
||||
ctx: CodeGenContext,
|
||||
ev: GeneratedExpressionCode,
|
||||
f: Term => Code): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
// reuse the previous isNull
|
||||
ev.isNull = eval.isNull
|
||||
eval.code + s"""
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.primitive} = ${f(eval.primitive)};
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Semantically we probably not need GroupExpression
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -49,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
|
|||
|
||||
private lazy val numeric = TypeUtils.getNumeric(dataType)
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
|
||||
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
|
||||
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
|
||||
}
|
||||
|
||||
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
|
||||
}
|
||||
|
||||
|
@ -67,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
|
|||
if (value < 0) null
|
||||
else math.sqrt(value)
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
eval.code + s"""
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
if (${eval.primitive} < 0.0) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
|
||||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -86,6 +107,9 @@ case class Abs(child: Expression) extends UnaryArithmetic {
|
|||
abstract class BinaryArithmetic extends BinaryExpression {
|
||||
self: Product =>
|
||||
|
||||
/** Name of the function for this expression on a [[Decimal]] type. */
|
||||
def decimalMethod: String = ""
|
||||
|
||||
override def dataType: DataType = left.dataType
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
|
@ -114,6 +138,17 @@ abstract class BinaryArithmetic extends BinaryExpression {
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
|
||||
case dt: DecimalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
|
||||
// byte and short are casted into int when add, minus, times or divide
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) =>
|
||||
s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case _ =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
|
||||
}
|
||||
|
||||
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
|
||||
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
|
||||
}
|
||||
|
@ -124,6 +159,7 @@ private[sql] object BinaryArithmetic {
|
|||
|
||||
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "+"
|
||||
override def decimalMethod: String = "$plus"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
@ -138,6 +174,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
|
||||
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "-"
|
||||
override def decimalMethod: String = "$minus"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
@ -152,6 +189,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
|
||||
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "*"
|
||||
override def decimalMethod: String = "$times"
|
||||
|
||||
override lazy val resolved =
|
||||
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
|
||||
|
@ -166,6 +204,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
|
||||
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "/"
|
||||
override def decimalMethod: String = "$divide"
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override lazy val resolved =
|
||||
|
@ -192,10 +232,40 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Special case handling due to division by 0 => null.
|
||||
*/
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
val test = if (left.dataType.isInstanceOf[DecimalType]) {
|
||||
s"${eval2.primitive}.isZero()"
|
||||
} else {
|
||||
s"${eval2.primitive} == 0"
|
||||
}
|
||||
val method = if (left.dataType.isInstanceOf[DecimalType]) {
|
||||
s".$decimalMethod"
|
||||
} else {
|
||||
s"$symbol"
|
||||
}
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
|
||||
if (${eval1.isNull} || ${eval2.isNull} || $test) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "%"
|
||||
override def decimalMethod: String = "reminder"
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override lazy val resolved =
|
||||
|
@ -222,6 +292,34 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Special case handling for x % 0 ==> null.
|
||||
*/
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
val test = if (left.dataType.isInstanceOf[DecimalType]) {
|
||||
s"${eval2.primitive}.isZero()"
|
||||
} else {
|
||||
s"${eval2.primitive} == 0"
|
||||
}
|
||||
val method = if (left.dataType.isInstanceOf[DecimalType]) {
|
||||
s".$decimalMethod"
|
||||
} else {
|
||||
s"$symbol"
|
||||
}
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
|
||||
if (${eval1.isNull} || ${eval2.isNull} || $test) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -271,7 +369,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
|
|||
}
|
||||
|
||||
/**
|
||||
* A function that calculates bitwise xor(^) of two numbers.
|
||||
* A function that calculates bitwise xor of two numbers.
|
||||
*/
|
||||
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
override def symbol: String = "^"
|
||||
|
@ -313,6 +411,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
|
|||
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)")
|
||||
}
|
||||
|
||||
protected override def evalInternal(evalE: Any) = not(evalE)
|
||||
}
|
||||
|
||||
|
@ -340,6 +442,33 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
if (ctx.isNativeType(left.dataType)) {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
eval1.code + eval2.code + s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(left.dataType)} ${ev.primitive} =
|
||||
${ctx.defaultValue(left.dataType)};
|
||||
|
||||
if (${eval1.isNull}) {
|
||||
${ev.isNull} = ${eval2.isNull};
|
||||
${ev.primitive} = ${eval2.primitive};
|
||||
} else if (${eval2.isNull}) {
|
||||
${ev.isNull} = ${eval1.isNull};
|
||||
${ev.primitive} = ${eval1.primitive};
|
||||
} else {
|
||||
if (${eval1.primitive} > ${eval2.primitive}) {
|
||||
${ev.primitive} = ${eval1.primitive};
|
||||
} else {
|
||||
${ev.primitive} = ${eval2.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
override def toString: String = s"MaxOf($left, $right)"
|
||||
}
|
||||
|
||||
|
@ -367,5 +496,35 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
if (ctx.isNativeType(left.dataType)) {
|
||||
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
|
||||
eval1.code + eval2.code + s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(left.dataType)} ${ev.primitive} =
|
||||
${ctx.defaultValue(left.dataType)};
|
||||
|
||||
if (${eval1.isNull}) {
|
||||
${ev.isNull} = ${eval2.isNull};
|
||||
${ev.primitive} = ${eval2.primitive};
|
||||
} else if (${eval2.isNull}) {
|
||||
${ev.isNull} = ${eval1.isNull};
|
||||
${ev.primitive} = ${eval1.primitive};
|
||||
} else {
|
||||
if (${eval1.primitive} < ${eval2.primitive}) {
|
||||
${ev.primitive} = ${eval1.primitive};
|
||||
} else {
|
||||
${ev.primitive} = ${eval2.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
|
||||
override def toString: String = s"MinOf($left, $right)"
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
|
|||
import org.codehaus.janino.ClassBodyEvaluator
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.catalyst.expressions
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -32,6 +31,157 @@ import org.apache.spark.sql.types._
|
|||
class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int]
|
||||
class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
|
||||
|
||||
/**
|
||||
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
|
||||
*
|
||||
* @param code The sequence of statements required to evaluate the expression.
|
||||
* @param isNull A term that holds a boolean value representing whether the expression evaluated
|
||||
* to null.
|
||||
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
|
||||
* valid if `isNull` is set to `true`.
|
||||
*/
|
||||
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
|
||||
|
||||
/**
|
||||
* A context for codegen, which is used to bookkeeping the expressions those are not supported
|
||||
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
|
||||
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
|
||||
*/
|
||||
class CodeGenContext {
|
||||
|
||||
/**
|
||||
* Holding all the expressions those do not support codegen, will be evaluated directly.
|
||||
*/
|
||||
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
|
||||
|
||||
val stringType: String = classOf[UTF8String].getName
|
||||
val decimalType: String = classOf[Decimal].getName
|
||||
|
||||
private val curId = new java.util.concurrent.atomic.AtomicInteger()
|
||||
|
||||
/**
|
||||
* Returns a term name that is unique within this instance of a `CodeGenerator`.
|
||||
*
|
||||
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
|
||||
* function.)
|
||||
*/
|
||||
def freshName(prefix: String): Term = {
|
||||
s"$prefix${curId.getAndIncrement}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the code to access a column for given DataType
|
||||
*/
|
||||
def getColumn(dataType: DataType, ordinal: Int): Code = {
|
||||
if (isNativeType(dataType)) {
|
||||
s"i.${accessorForType(dataType)}($ordinal)"
|
||||
} else {
|
||||
s"(${boxedType(dataType)})i.apply($ordinal)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the code to update a column in Row for given DataType
|
||||
*/
|
||||
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
|
||||
if (isNativeType(dataType)) {
|
||||
s"${mutatorForType(dataType)}($ordinal, $value)"
|
||||
} else {
|
||||
s"update($ordinal, $value)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the name of accessor in Row for a DataType
|
||||
*/
|
||||
def accessorForType(dt: DataType): Term = dt match {
|
||||
case IntegerType => "getInt"
|
||||
case other => s"get${boxedType(dt)}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the name of mutator in Row for a DataType
|
||||
*/
|
||||
def mutatorForType(dt: DataType): Term = dt match {
|
||||
case IntegerType => "setInt"
|
||||
case other => s"set${boxedType(dt)}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the Java type for a DataType
|
||||
*/
|
||||
def javaType(dt: DataType): Term = dt match {
|
||||
case IntegerType => "int"
|
||||
case LongType => "long"
|
||||
case ShortType => "short"
|
||||
case ByteType => "byte"
|
||||
case DoubleType => "double"
|
||||
case FloatType => "float"
|
||||
case BooleanType => "boolean"
|
||||
case dt: DecimalType => decimalType
|
||||
case BinaryType => "byte[]"
|
||||
case StringType => stringType
|
||||
case DateType => "int"
|
||||
case TimestampType => "java.sql.Timestamp"
|
||||
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
|
||||
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
|
||||
case _ => "Object"
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the boxed type in Java
|
||||
*/
|
||||
def boxedType(dt: DataType): Term = dt match {
|
||||
case IntegerType => "Integer"
|
||||
case LongType => "Long"
|
||||
case ShortType => "Short"
|
||||
case ByteType => "Byte"
|
||||
case DoubleType => "Double"
|
||||
case FloatType => "Float"
|
||||
case BooleanType => "Boolean"
|
||||
case DateType => "Integer"
|
||||
case _ => javaType(dt)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the representation of default value for given DataType
|
||||
*/
|
||||
def defaultValue(dt: DataType): Term = dt match {
|
||||
case BooleanType => "false"
|
||||
case FloatType => "-1.0f"
|
||||
case ShortType => "(short)-1"
|
||||
case LongType => "-1L"
|
||||
case ByteType => "(byte)-1"
|
||||
case DoubleType => "-1.0"
|
||||
case IntegerType => "-1"
|
||||
case DateType => "-1"
|
||||
case _ => "null"
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a function to generate equal expression in Java
|
||||
*/
|
||||
def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match {
|
||||
case BinaryType => { case (eval1, eval2) =>
|
||||
s"java.util.Arrays.equals($eval1, $eval2)" }
|
||||
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType =>
|
||||
{ case (eval1, eval2) => s"$eval1 == $eval2" }
|
||||
case other =>
|
||||
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
|
||||
}
|
||||
|
||||
/**
|
||||
* List of data types that have special accessors and setters in [[Row]].
|
||||
*/
|
||||
val nativeTypes =
|
||||
Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
|
||||
|
||||
/**
|
||||
* Returns true if the data type has a special accessor and setter in [[Row]].
|
||||
*/
|
||||
def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt)
|
||||
}
|
||||
|
||||
/**
|
||||
* A base class for generators of byte code to perform expression evaluation. Includes a set of
|
||||
* helpers for referring to Catalyst types and building trees that perform evaluation of individual
|
||||
|
@ -39,14 +189,9 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
|
|||
*/
|
||||
abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
|
||||
|
||||
protected val rowType = classOf[Row].getName
|
||||
protected val stringType = classOf[UTF8String].getName
|
||||
protected val decimalType = classOf[Decimal].getName
|
||||
protected val exprType = classOf[Expression].getName
|
||||
protected val mutableRowType = classOf[MutableRow].getName
|
||||
protected val genericMutableRowType = classOf[GenericMutableRow].getName
|
||||
|
||||
private val curId = new java.util.concurrent.atomic.AtomicInteger()
|
||||
protected val exprType: String = classOf[Expression].getName
|
||||
protected val mutableRowType: String = classOf[MutableRow].getName
|
||||
protected val genericMutableRowType: String = classOf[GenericMutableRow].getName
|
||||
|
||||
/**
|
||||
* Can be flipped on manually in the console to add (expensive) expression evaluation trace code.
|
||||
|
@ -75,10 +220,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
|
|||
*/
|
||||
protected def compile(code: String): Class[_] = {
|
||||
val startTime = System.nanoTime()
|
||||
val clazz = new ClassBodyEvaluator(code).getClazz()
|
||||
val clazz = try {
|
||||
new ClassBodyEvaluator(code).getClazz()
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"failed to compile:\n $code", e)
|
||||
throw e
|
||||
}
|
||||
val endTime = System.nanoTime()
|
||||
def timeMs: Double = (endTime - startTime).toDouble / 1000000
|
||||
logDebug(s"Compiled Java code (${code.size} bytes) in $timeMs ms")
|
||||
logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms")
|
||||
clazz
|
||||
}
|
||||
|
||||
|
@ -112,586 +263,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
|
|||
/** Generates the requested evaluator given already bound expression(s). */
|
||||
def generate(expressions: InType): OutType = cache.get(canonicalize(expressions))
|
||||
|
||||
/**
|
||||
* Returns a term name that is unique within this instance of a `CodeGenerator`.
|
||||
*
|
||||
* (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
|
||||
* function.)
|
||||
*/
|
||||
protected def freshName(prefix: String): String = {
|
||||
s"$prefix${curId.getAndIncrement}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Scala ASTs for evaluating an [[Expression]] given a [[Row]] of input.
|
||||
*
|
||||
* @param code The sequence of statements required to evaluate the expression.
|
||||
* @param nullTerm A term that holds a boolean value representing whether the expression evaluated
|
||||
* to null.
|
||||
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
|
||||
* valid if `nullTerm` is set to `true`.
|
||||
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
|
||||
*/
|
||||
protected case class EvaluatedExpression(
|
||||
code: String,
|
||||
nullTerm: String,
|
||||
primitiveTerm: String,
|
||||
objectTerm: String)
|
||||
|
||||
/**
|
||||
* A context for codegen, which is used to bookkeeping the expressions those are not supported
|
||||
* by codegen, then they are evaluated directly. The unsupported expression is appended at the
|
||||
* end of `references`, the position of it is kept in the code, used to access and evaluate it.
|
||||
*/
|
||||
protected class CodeGenContext {
|
||||
/**
|
||||
* Holding all the expressions those do not support codegen, will be evaluated directly.
|
||||
*/
|
||||
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new codegen context for expression evaluator, used to store those
|
||||
* expressions that don't support codegen
|
||||
*/
|
||||
def newCodeGenContext(): CodeGenContext = {
|
||||
new CodeGenContext()
|
||||
new CodeGenContext
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an expression tree returns an [[EvaluatedExpression]], which contains Scala trees that
|
||||
* can be used to determine the result of evaluating the expression on an input row.
|
||||
*/
|
||||
def expressionEvaluator(e: Expression, ctx: CodeGenContext): EvaluatedExpression = {
|
||||
val primitiveTerm = freshName("primitiveTerm")
|
||||
val nullTerm = freshName("nullTerm")
|
||||
val objectTerm = freshName("objectTerm")
|
||||
|
||||
implicit class Evaluate1(e: Expression) {
|
||||
def castOrNull(f: String => String, dataType: DataType): String = {
|
||||
val eval = expressionEvaluator(e, ctx)
|
||||
eval.code +
|
||||
s"""
|
||||
boolean $nullTerm = ${eval.nullTerm};
|
||||
${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
|
||||
if (!$nullTerm) {
|
||||
$primitiveTerm = ${f(eval.primitiveTerm)};
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
implicit class Evaluate2(expressions: (Expression, Expression)) {
|
||||
|
||||
/**
|
||||
* Short hand for generating binary evaluation code, which depends on two sub-evaluations of
|
||||
* the same type. If either of the sub-expressions is null, the result of this computation
|
||||
* is assumed to be null.
|
||||
*
|
||||
* @param f a function from two primitive term names to a tree that evaluates them.
|
||||
*/
|
||||
def evaluate(f: (String, String) => String): String =
|
||||
evaluateAs(expressions._1.dataType)(f)
|
||||
|
||||
def evaluateAs(resultType: DataType)(f: (String, String) => String): String = {
|
||||
// TODO: Right now some timestamp tests fail if we enforce this...
|
||||
if (expressions._1.dataType != expressions._2.dataType) {
|
||||
log.warn(s"${expressions._1.dataType} != ${expressions._2.dataType}")
|
||||
}
|
||||
|
||||
val eval1 = expressionEvaluator(expressions._1, ctx)
|
||||
val eval2 = expressionEvaluator(expressions._2, ctx)
|
||||
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
|
||||
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
|
||||
${primitiveForType(resultType)} $primitiveTerm = ${defaultPrimitive(resultType)};
|
||||
if(!$nullTerm) {
|
||||
$primitiveTerm = (${primitiveForType(resultType)})($resultCode);
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
val inputTuple = "i"
|
||||
|
||||
// TODO: Skip generation of null handling code when expression are not nullable.
|
||||
val primitiveEvaluation: PartialFunction[Expression, String] = {
|
||||
case b @ BoundReference(ordinal, dataType, nullable) =>
|
||||
s"""
|
||||
final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
|
||||
final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
|
||||
${defaultPrimitive(dataType)} : (${getColumn(inputTuple, dataType, ordinal)});
|
||||
"""
|
||||
|
||||
case expressions.Literal(null, dataType) =>
|
||||
s"""
|
||||
final boolean $nullTerm = true;
|
||||
${primitiveForType(dataType)} $primitiveTerm = ${defaultPrimitive(dataType)};
|
||||
"""
|
||||
|
||||
case expressions.Literal(value: UTF8String, StringType) =>
|
||||
val arr = s"new byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
|
||||
s"""
|
||||
final boolean $nullTerm = false;
|
||||
${stringType} $primitiveTerm =
|
||||
new ${stringType}().set(${arr});
|
||||
"""
|
||||
|
||||
case expressions.Literal(value, FloatType) =>
|
||||
s"""
|
||||
final boolean $nullTerm = false;
|
||||
float $primitiveTerm = ${value}f;
|
||||
"""
|
||||
|
||||
case expressions.Literal(value, dt @ DecimalType()) =>
|
||||
s"""
|
||||
final boolean $nullTerm = false;
|
||||
${primitiveForType(dt)} $primitiveTerm = new ${primitiveForType(dt)}().set($value);
|
||||
"""
|
||||
|
||||
case expressions.Literal(value, dataType) =>
|
||||
s"""
|
||||
final boolean $nullTerm = false;
|
||||
${primitiveForType(dataType)} $primitiveTerm = $value;
|
||||
"""
|
||||
|
||||
case Cast(child @ BinaryType(), StringType) =>
|
||||
child.castOrNull(c =>
|
||||
s"new ${stringType}().set($c)",
|
||||
StringType)
|
||||
|
||||
case Cast(child @ DateType(), StringType) =>
|
||||
child.castOrNull(c =>
|
||||
s"""new ${stringType}().set(
|
||||
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
|
||||
StringType)
|
||||
|
||||
case Cast(child @ BooleanType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
|
||||
child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
|
||||
|
||||
case Cast(child @ DecimalType(), IntegerType) =>
|
||||
child.castOrNull(c => s"($c).toInt()", IntegerType)
|
||||
|
||||
case Cast(child @ DecimalType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
|
||||
child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
|
||||
|
||||
case Cast(child @ NumericType(), dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
|
||||
child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
|
||||
|
||||
// Special handling required for timestamps in hive test cases since the toString function
|
||||
// does not match the expected output.
|
||||
case Cast(e, StringType) if e.dataType != TimestampType =>
|
||||
e.castOrNull(c =>
|
||||
s"new ${stringType}().set(String.valueOf($c))",
|
||||
StringType)
|
||||
|
||||
case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
|
||||
(e1, e2).evaluateAs (BooleanType) {
|
||||
case (eval1, eval2) =>
|
||||
s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
|
||||
}
|
||||
|
||||
case EqualTo(e1, e2) =>
|
||||
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 == $eval2" }
|
||||
|
||||
case GreaterThan(e1 @ NumericType(), e2 @ NumericType()) =>
|
||||
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 > $eval2" }
|
||||
case GreaterThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
|
||||
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 >= $eval2" }
|
||||
case LessThan(e1 @ NumericType(), e2 @ NumericType()) =>
|
||||
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 < $eval2" }
|
||||
case LessThanOrEqual(e1 @ NumericType(), e2 @ NumericType()) =>
|
||||
(e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => s"$eval1 <= $eval2" }
|
||||
|
||||
case And(e1, e2) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
s"""
|
||||
${eval1.code}
|
||||
boolean $nullTerm = false;
|
||||
boolean $primitiveTerm = false;
|
||||
|
||||
if (!${eval1.nullTerm} && !${eval1.primitiveTerm}) {
|
||||
} else {
|
||||
${eval2.code}
|
||||
if (!${eval2.nullTerm} && !${eval2.primitiveTerm}) {
|
||||
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
|
||||
$primitiveTerm = true;
|
||||
} else {
|
||||
$nullTerm = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
case Or(e1, e2) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
|
||||
s"""
|
||||
${eval1.code}
|
||||
boolean $nullTerm = false;
|
||||
boolean $primitiveTerm = false;
|
||||
|
||||
if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
|
||||
$primitiveTerm = true;
|
||||
} else {
|
||||
${eval2.code}
|
||||
if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
|
||||
$primitiveTerm = true;
|
||||
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
|
||||
$primitiveTerm = false;
|
||||
} else {
|
||||
$nullTerm = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
case Not(child) =>
|
||||
// Uh, bad function name...
|
||||
child.castOrNull(c => s"!$c", BooleanType)
|
||||
|
||||
case Add(e1 @ DecimalType(), e2 @ DecimalType()) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$plus($eval2)" }
|
||||
case Subtract(e1 @ DecimalType(), e2 @ DecimalType()) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$minus($eval2)" }
|
||||
case Multiply(e1 @ DecimalType(), e2 @ DecimalType()) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1.$$times($eval2)" }
|
||||
case Divide(e1 @ DecimalType(), e2 @ DecimalType()) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = null;
|
||||
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
|
||||
$nullTerm = true;
|
||||
} else {
|
||||
$primitiveTerm = ${eval1.primitiveTerm}.$$div${eval2.primitiveTerm});
|
||||
}
|
||||
"""
|
||||
case Remainder(e1 @ DecimalType(), e2 @ DecimalType()) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
|
||||
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm}.isZero()) {
|
||||
$nullTerm = true;
|
||||
} else {
|
||||
$primitiveTerm = ${eval1.primitiveTerm}.remainder(${eval2.primitiveTerm});
|
||||
}
|
||||
"""
|
||||
|
||||
case Add(e1, e2) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 + $eval2" }
|
||||
case Subtract(e1, e2) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 - $eval2" }
|
||||
case Multiply(e1, e2) =>
|
||||
(e1, e2) evaluate { case (eval1, eval2) => s"$eval1 * $eval2" }
|
||||
case Divide(e1, e2) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
|
||||
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
|
||||
$nullTerm = true;
|
||||
} else {
|
||||
$primitiveTerm = ${eval1.primitiveTerm} / ${eval2.primitiveTerm};
|
||||
}
|
||||
"""
|
||||
case Remainder(e1, e2) =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = 0;
|
||||
if (${eval1.nullTerm} || ${eval2.nullTerm} || ${eval2.primitiveTerm} == 0) {
|
||||
$nullTerm = true;
|
||||
} else {
|
||||
$primitiveTerm = ${eval1.primitiveTerm} % ${eval2.primitiveTerm};
|
||||
}
|
||||
"""
|
||||
|
||||
case IsNotNull(e) =>
|
||||
val eval = expressionEvaluator(e, ctx)
|
||||
s"""
|
||||
${eval.code}
|
||||
boolean $nullTerm = false;
|
||||
boolean $primitiveTerm = !${eval.nullTerm};
|
||||
"""
|
||||
|
||||
case IsNull(e) =>
|
||||
val eval = expressionEvaluator(e, ctx)
|
||||
s"""
|
||||
${eval.code}
|
||||
boolean $nullTerm = false;
|
||||
boolean $primitiveTerm = ${eval.nullTerm};
|
||||
"""
|
||||
|
||||
case e @ Coalesce(children) =>
|
||||
s"""
|
||||
boolean $nullTerm = true;
|
||||
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
|
||||
""" +
|
||||
children.map { c =>
|
||||
val eval = expressionEvaluator(c, ctx)
|
||||
s"""
|
||||
if($nullTerm) {
|
||||
${eval.code}
|
||||
if(!${eval.nullTerm}) {
|
||||
$nullTerm = false;
|
||||
$primitiveTerm = ${eval.primitiveTerm};
|
||||
}
|
||||
}
|
||||
"""
|
||||
}.mkString("\n")
|
||||
|
||||
case e @ expressions.If(condition, trueValue, falseValue) =>
|
||||
val condEval = expressionEvaluator(condition, ctx)
|
||||
val trueEval = expressionEvaluator(trueValue, ctx)
|
||||
val falseEval = expressionEvaluator(falseValue, ctx)
|
||||
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
|
||||
${condEval.code}
|
||||
if(!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
|
||||
${trueEval.code}
|
||||
$nullTerm = ${trueEval.nullTerm};
|
||||
$primitiveTerm = ${trueEval.primitiveTerm};
|
||||
} else {
|
||||
${falseEval.code}
|
||||
$nullTerm = ${falseEval.nullTerm};
|
||||
$primitiveTerm = ${falseEval.primitiveTerm};
|
||||
}
|
||||
"""
|
||||
|
||||
case NewSet(elementType) =>
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${hashSetForType(elementType)} $primitiveTerm = new ${hashSetForType(elementType)}();
|
||||
"""
|
||||
|
||||
case AddItemToSet(item, set) =>
|
||||
val itemEval = expressionEvaluator(item, ctx)
|
||||
val setEval = expressionEvaluator(set, ctx)
|
||||
|
||||
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
|
||||
val htype = hashSetForType(elementType)
|
||||
|
||||
itemEval.code + setEval.code +
|
||||
s"""
|
||||
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
|
||||
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
|
||||
}
|
||||
boolean $nullTerm = false;
|
||||
${htype} $primitiveTerm = ($htype)${setEval.primitiveTerm};
|
||||
"""
|
||||
|
||||
case CombineSets(left, right) =>
|
||||
val leftEval = expressionEvaluator(left, ctx)
|
||||
val rightEval = expressionEvaluator(right, ctx)
|
||||
|
||||
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
|
||||
val htype = hashSetForType(elementType)
|
||||
|
||||
leftEval.code + rightEval.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${htype} $primitiveTerm =
|
||||
(${htype})${leftEval.primitiveTerm};
|
||||
$primitiveTerm.union((${htype})${rightEval.primitiveTerm});
|
||||
"""
|
||||
|
||||
case MaxOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
|
||||
|
||||
if (${eval1.nullTerm}) {
|
||||
$nullTerm = ${eval2.nullTerm};
|
||||
$primitiveTerm = ${eval2.primitiveTerm};
|
||||
} else if (${eval2.nullTerm}) {
|
||||
$nullTerm = ${eval1.nullTerm};
|
||||
$primitiveTerm = ${eval1.primitiveTerm};
|
||||
} else {
|
||||
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
|
||||
$primitiveTerm = ${eval1.primitiveTerm};
|
||||
} else {
|
||||
$primitiveTerm = ${eval2.primitiveTerm};
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
case MinOf(e1, e2) if !e1.dataType.isInstanceOf[DecimalType] =>
|
||||
val eval1 = expressionEvaluator(e1, ctx)
|
||||
val eval2 = expressionEvaluator(e2, ctx)
|
||||
|
||||
eval1.code + eval2.code +
|
||||
s"""
|
||||
boolean $nullTerm = false;
|
||||
${primitiveForType(e1.dataType)} $primitiveTerm = ${defaultPrimitive(e1.dataType)};
|
||||
|
||||
if (${eval1.nullTerm}) {
|
||||
$nullTerm = ${eval2.nullTerm};
|
||||
$primitiveTerm = ${eval2.primitiveTerm};
|
||||
} else if (${eval2.nullTerm}) {
|
||||
$nullTerm = ${eval1.nullTerm};
|
||||
$primitiveTerm = ${eval1.primitiveTerm};
|
||||
} else {
|
||||
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
|
||||
$primitiveTerm = ${eval1.primitiveTerm};
|
||||
} else {
|
||||
$primitiveTerm = ${eval2.primitiveTerm};
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
case UnscaledValue(child) =>
|
||||
val childEval = expressionEvaluator(child, ctx)
|
||||
|
||||
childEval.code +
|
||||
s"""
|
||||
boolean $nullTerm = ${childEval.nullTerm};
|
||||
long $primitiveTerm = $nullTerm ? -1 : ${childEval.primitiveTerm}.toUnscaledLong();
|
||||
"""
|
||||
|
||||
case MakeDecimal(child, precision, scale) =>
|
||||
val eval = expressionEvaluator(child, ctx)
|
||||
|
||||
eval.code +
|
||||
s"""
|
||||
boolean $nullTerm = ${eval.nullTerm};
|
||||
org.apache.spark.sql.types.Decimal $primitiveTerm = ${defaultPrimitive(DecimalType())};
|
||||
|
||||
if (!$nullTerm) {
|
||||
$primitiveTerm = new org.apache.spark.sql.types.Decimal();
|
||||
$primitiveTerm = $primitiveTerm.setOrNull(${eval.primitiveTerm}, $precision, $scale);
|
||||
$nullTerm = $primitiveTerm == null;
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
// If there was no match in the partial function above, we fall back on calling the interpreted
|
||||
// expression evaluator.
|
||||
val code: String =
|
||||
primitiveEvaluation.lift.apply(e).getOrElse {
|
||||
logError(s"No rules to generate $e")
|
||||
ctx.references += e
|
||||
s"""
|
||||
/* expression: ${e} */
|
||||
Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i);
|
||||
boolean $nullTerm = $objectTerm == null;
|
||||
${primitiveForType(e.dataType)} $primitiveTerm = ${defaultPrimitive(e.dataType)};
|
||||
if (!$nullTerm) $primitiveTerm = (${termForType(e.dataType)})$objectTerm;
|
||||
"""
|
||||
}
|
||||
|
||||
EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm)
|
||||
}
|
||||
|
||||
protected def getColumn(inputRow: String, dataType: DataType, ordinal: Int) = {
|
||||
dataType match {
|
||||
case StringType => s"(${stringType})$inputRow.apply($ordinal)"
|
||||
case dt: DataType if isNativeType(dt) => s"$inputRow.${accessorForType(dt)}($ordinal)"
|
||||
case _ => s"(${termForType(dataType)})$inputRow.apply($ordinal)"
|
||||
}
|
||||
}
|
||||
|
||||
protected def setColumn(
|
||||
destinationRow: String,
|
||||
dataType: DataType,
|
||||
ordinal: Int,
|
||||
value: String): String = {
|
||||
dataType match {
|
||||
case StringType => s"$destinationRow.update($ordinal, $value)"
|
||||
case dt: DataType if isNativeType(dt) =>
|
||||
s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
|
||||
case _ => s"$destinationRow.update($ordinal, $value)"
|
||||
}
|
||||
}
|
||||
|
||||
protected def accessorForType(dt: DataType) = dt match {
|
||||
case IntegerType => "getInt"
|
||||
case other => s"get${termForType(dt)}"
|
||||
}
|
||||
|
||||
protected def mutatorForType(dt: DataType) = dt match {
|
||||
case IntegerType => "setInt"
|
||||
case other => s"set${termForType(dt)}"
|
||||
}
|
||||
|
||||
protected def hashSetForType(dt: DataType): String = dt match {
|
||||
case IntegerType => classOf[IntegerHashSet].getName
|
||||
case LongType => classOf[LongHashSet].getName
|
||||
case unsupportedType =>
|
||||
sys.error(s"Code generation not support for hashset of type $unsupportedType")
|
||||
}
|
||||
|
||||
protected def primitiveForType(dt: DataType): String = dt match {
|
||||
case IntegerType => "int"
|
||||
case LongType => "long"
|
||||
case ShortType => "short"
|
||||
case ByteType => "byte"
|
||||
case DoubleType => "double"
|
||||
case FloatType => "float"
|
||||
case BooleanType => "boolean"
|
||||
case dt: DecimalType => decimalType
|
||||
case BinaryType => "byte[]"
|
||||
case StringType => stringType
|
||||
case DateType => "int"
|
||||
case TimestampType => "java.sql.Timestamp"
|
||||
case _ => "Object"
|
||||
}
|
||||
|
||||
protected def defaultPrimitive(dt: DataType): String = dt match {
|
||||
case BooleanType => "false"
|
||||
case FloatType => "-1.0f"
|
||||
case ShortType => "-1"
|
||||
case LongType => "-1"
|
||||
case ByteType => "-1"
|
||||
case DoubleType => "-1.0"
|
||||
case IntegerType => "-1"
|
||||
case DateType => "-1"
|
||||
case dt: DecimalType => "null"
|
||||
case StringType => "null"
|
||||
case _ => "null"
|
||||
}
|
||||
|
||||
protected def termForType(dt: DataType): String = dt match {
|
||||
case IntegerType => "Integer"
|
||||
case LongType => "Long"
|
||||
case ShortType => "Short"
|
||||
case ByteType => "Byte"
|
||||
case DoubleType => "Double"
|
||||
case FloatType => "Float"
|
||||
case BooleanType => "Boolean"
|
||||
case dt: DecimalType => decimalType
|
||||
case BinaryType => "byte[]"
|
||||
case StringType => stringType
|
||||
case DateType => "Integer"
|
||||
case TimestampType => "java.sql.Timestamp"
|
||||
case _ => "Object"
|
||||
}
|
||||
|
||||
/**
|
||||
* List of data types that have special accessors and setters in [[Row]].
|
||||
*/
|
||||
protected val nativeTypes =
|
||||
Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType)
|
||||
|
||||
/**
|
||||
* Returns true if the data type has a special accessor and setter in [[Row]].
|
||||
*/
|
||||
protected def isNativeType(dt: DataType) = nativeTypes.contains(dt)
|
||||
}
|
||||
|
|
|
@ -37,13 +37,13 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
|
|||
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
|
||||
val ctx = newCodeGenContext()
|
||||
val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
|
||||
val evaluationCode = expressionEvaluator(e, ctx)
|
||||
val evaluationCode = e.gen(ctx)
|
||||
evaluationCode.code +
|
||||
s"""
|
||||
if(${evaluationCode.nullTerm})
|
||||
if(${evaluationCode.isNull})
|
||||
mutableRow.setNullAt($i);
|
||||
else
|
||||
${setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
|
||||
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
|
||||
"""
|
||||
}.mkString("\n")
|
||||
val code = s"""
|
||||
|
|
|
@ -52,15 +52,15 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
|
|||
val ctx = newCodeGenContext()
|
||||
|
||||
val comparisons = ordering.zipWithIndex.map { case (order, i) =>
|
||||
val evalA = expressionEvaluator(order.child, ctx)
|
||||
val evalB = expressionEvaluator(order.child, ctx)
|
||||
val evalA = order.child.gen(ctx)
|
||||
val evalB = order.child.gen(ctx)
|
||||
val asc = order.direction == Ascending
|
||||
val compare = order.child.dataType match {
|
||||
case BinaryType =>
|
||||
s"""
|
||||
{
|
||||
byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
|
||||
byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
|
||||
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
|
||||
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
|
||||
int j = 0;
|
||||
while (j < x.length && j < y.length) {
|
||||
if (x[j] != y[j]) return x[j] - y[j];
|
||||
|
@ -73,8 +73,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
|
|||
}"""
|
||||
case _: NumericType =>
|
||||
s"""
|
||||
if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
|
||||
if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
|
||||
if (${evalA.primitive} != ${evalB.primitive}) {
|
||||
if (${evalA.primitive} > ${evalB.primitive}) {
|
||||
return ${if (asc) "1" else "-1"};
|
||||
} else {
|
||||
return ${if (asc) "-1" else "1"};
|
||||
|
@ -82,7 +82,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
|
|||
}"""
|
||||
case _ =>
|
||||
s"""
|
||||
int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
|
||||
int comp = ${evalA.primitive}.compare(${evalB.primitive});
|
||||
if (comp != 0) {
|
||||
return ${if (asc) "comp" else "-comp"};
|
||||
}"""
|
||||
|
@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
|
|||
${evalA.code}
|
||||
i = $b;
|
||||
${evalB.code}
|
||||
if (${evalA.nullTerm} && ${evalB.nullTerm}) {
|
||||
if (${evalA.isNull} && ${evalB.isNull}) {
|
||||
// Nothing
|
||||
} else if (${evalA.nullTerm}) {
|
||||
} else if (${evalA.isNull}) {
|
||||
return ${if (order.direction == Ascending) "-1" else "1"};
|
||||
} else if (${evalB.nullTerm}) {
|
||||
} else if (${evalB.isNull}) {
|
||||
return ${if (order.direction == Ascending) "1" else "-1"};
|
||||
} else {
|
||||
$compare
|
||||
|
|
|
@ -38,7 +38,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
|
|||
|
||||
protected def create(predicate: Expression): ((Row) => Boolean) = {
|
||||
val ctx = newCodeGenContext()
|
||||
val eval = expressionEvaluator(predicate, ctx)
|
||||
val eval = predicate.gen(ctx)
|
||||
val code = s"""
|
||||
import org.apache.spark.sql.Row;
|
||||
|
||||
|
@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
|
|||
@Override
|
||||
public boolean eval(Row i) {
|
||||
${eval.code}
|
||||
return !${eval.nullTerm} && ${eval.primitiveTerm};
|
||||
return !${eval.isNull} && ${eval.primitive};
|
||||
}
|
||||
}"""
|
||||
|
||||
|
|
|
@ -45,19 +45,19 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
|
|||
val ctx = newCodeGenContext()
|
||||
val columns = expressions.zipWithIndex.map {
|
||||
case (e, i) =>
|
||||
s"private ${primitiveForType(e.dataType)} c$i = ${defaultPrimitive(e.dataType)};\n"
|
||||
s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n"
|
||||
}.mkString("\n ")
|
||||
|
||||
val initColumns = expressions.zipWithIndex.map {
|
||||
case (e, i) =>
|
||||
val eval = expressionEvaluator(e, ctx)
|
||||
val eval = e.gen(ctx)
|
||||
s"""
|
||||
{
|
||||
// column$i
|
||||
${eval.code}
|
||||
nullBits[$i] = ${eval.nullTerm};
|
||||
if(!${eval.nullTerm}) {
|
||||
c$i = ${eval.primitiveTerm};
|
||||
nullBits[$i] = ${eval.isNull};
|
||||
if (!${eval.isNull}) {
|
||||
c$i = ${eval.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
@ -68,10 +68,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
|
|||
}.mkString("\n ")
|
||||
|
||||
val updateCases = expressions.zipWithIndex.map { case (e, i) =>
|
||||
s"case $i: { c$i = (${termForType(e.dataType)})value; return;}"
|
||||
s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}"
|
||||
}.mkString("\n ")
|
||||
|
||||
val specificAccessorFunctions = nativeTypes.map { dataType =>
|
||||
val specificAccessorFunctions = ctx.nativeTypes.map { dataType =>
|
||||
val cases = expressions.zipWithIndex.map {
|
||||
case (e, i) if e.dataType == dataType =>
|
||||
s"case $i: return c$i;"
|
||||
|
@ -80,21 +80,21 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
|
|||
if (cases.count(_ != '\n') > 0) {
|
||||
s"""
|
||||
@Override
|
||||
public ${primitiveForType(dataType)} ${accessorForType(dataType)}(int i) {
|
||||
public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) {
|
||||
if (isNullAt(i)) {
|
||||
return ${defaultPrimitive(dataType)};
|
||||
return ${ctx.defaultValue(dataType)};
|
||||
}
|
||||
switch (i) {
|
||||
$cases
|
||||
}
|
||||
return ${defaultPrimitive(dataType)};
|
||||
return ${ctx.defaultValue(dataType)};
|
||||
}"""
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}.mkString("\n")
|
||||
|
||||
val specificMutatorFunctions = nativeTypes.map { dataType =>
|
||||
val specificMutatorFunctions = ctx.nativeTypes.map { dataType =>
|
||||
val cases = expressions.zipWithIndex.map {
|
||||
case (e, i) if e.dataType == dataType =>
|
||||
s"case $i: { c$i = value; return; }"
|
||||
|
@ -103,7 +103,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
|
|||
if (cases.count(_ != '\n') > 0) {
|
||||
s"""
|
||||
@Override
|
||||
public void ${mutatorForType(dataType)}(int i, ${primitiveForType(dataType)} value) {
|
||||
public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) {
|
||||
nullBits[i] = false;
|
||||
switch (i) {
|
||||
$cases
|
||||
|
@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
|
|||
case LongType => s"$col ^ ($col >>> 32)"
|
||||
case FloatType => s"Float.floatToIntBits($col)"
|
||||
case DoubleType =>
|
||||
s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
|
||||
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
|
||||
case _ => s"$col.hashCode()"
|
||||
}
|
||||
s"isNullAt($i) ? 0 : ($nonNull)"
|
||||
|
|
|
@ -27,6 +27,9 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
package object codegen {
|
||||
|
||||
type Term = String
|
||||
type Code = String
|
||||
|
||||
/** Canonicalizes an expression so those that differ only by names can reuse the same code. */
|
||||
object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] {
|
||||
val batches =
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
|
||||
|
@ -35,6 +36,10 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
|
|||
childResult.asInstanceOf[Decimal].toUnscaledLong
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
|
||||
}
|
||||
}
|
||||
|
||||
/** Create a Decimal from an unscaled Long value */
|
||||
|
@ -53,4 +58,18 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
|
|||
new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
eval.code + s"""
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${ctx.decimalType} ${ev.primitive} = null;
|
||||
|
||||
if (!${ev.isNull}) {
|
||||
${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
|
||||
${eval.primitive}, $precision, $scale);
|
||||
${ev.isNull} = ${ev.primitive} == null;
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
|
||||
import org.apache.spark.sql.catalyst.util.DateUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -78,7 +79,60 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
|
|||
|
||||
override def toString: String = if (value != null) value.toString else "null"
|
||||
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case o: Literal =>
|
||||
dataType.equals(o.dataType) &&
|
||||
(value == null && null == o.value || value != null && value.equals(o.value))
|
||||
case _ => false
|
||||
}
|
||||
|
||||
override def eval(input: Row): Any = value
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
// change the isNull and primitive to consts, to inline them
|
||||
if (value == null) {
|
||||
ev.isNull = "true"
|
||||
ev.primitive = ctx.defaultValue(dataType)
|
||||
""
|
||||
} else {
|
||||
dataType match {
|
||||
case BooleanType =>
|
||||
ev.isNull = "false"
|
||||
ev.primitive = value.toString
|
||||
""
|
||||
case FloatType => // This must go before NumericType
|
||||
val v = value.asInstanceOf[Float]
|
||||
if (v.isNaN || v.isInfinite) {
|
||||
super.genCode(ctx, ev)
|
||||
} else {
|
||||
ev.isNull = "false"
|
||||
ev.primitive = s"${value}f"
|
||||
""
|
||||
}
|
||||
case DoubleType => // This must go before NumericType
|
||||
val v = value.asInstanceOf[Double]
|
||||
if (v.isNaN || v.isInfinite) {
|
||||
super.genCode(ctx, ev)
|
||||
} else {
|
||||
ev.isNull = "false"
|
||||
ev.primitive = s"${value}"
|
||||
""
|
||||
}
|
||||
|
||||
case ByteType | ShortType => // This must go before NumericType
|
||||
ev.isNull = "false"
|
||||
ev.primitive = s"(${ctx.javaType(dataType)})$value"
|
||||
""
|
||||
case dt: NumericType if !dt.isInstanceOf[DecimalType] =>
|
||||
ev.isNull = "false"
|
||||
ev.primitive = value.toString
|
||||
""
|
||||
// eval() version may be faster for non-primitive types
|
||||
case other =>
|
||||
super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Specialize
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions.mathfuncs
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -49,6 +50,10 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)")
|
||||
}
|
||||
}
|
||||
|
||||
case class Atan2(left: Expression, right: Expression)
|
||||
|
@ -70,9 +75,26 @@ case class Atan2(left: Expression, right: Expression)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s"""
|
||||
if (Double.valueOf(${ev.primitive}).isNaN()) {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class Hypot(left: Expression, right: Expression)
|
||||
extends BinaryMathExpression(math.hypot, "HYPOT")
|
||||
|
||||
case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
|
||||
case class Pow(left: Expression, right: Expression)
|
||||
extends BinaryMathExpression(math.pow, "POWER") {
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
|
||||
if (Double.valueOf(${ev.primitive}).isNaN()) {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions.mathfuncs
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -44,6 +45,23 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
|
|||
if (result.isNaN) null else result
|
||||
}
|
||||
}
|
||||
|
||||
// name of function in java.lang.Math
|
||||
def funcName: String = name.toLowerCase
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
eval.code + s"""
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive});
|
||||
if (Double.valueOf(${ev.primitive}).isNaN()) {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
|
||||
|
@ -72,7 +90,9 @@ case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG
|
|||
|
||||
case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P")
|
||||
|
||||
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND")
|
||||
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
|
||||
override def funcName: String = "rint"
|
||||
}
|
||||
|
||||
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
|
||||
|
||||
|
@ -84,6 +104,10 @@ case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
|
|||
|
||||
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
|
||||
|
||||
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES")
|
||||
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
|
||||
override def funcName: String = "toDegrees"
|
||||
}
|
||||
|
||||
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS")
|
||||
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
|
||||
override def funcName: String = "toRadians"
|
||||
}
|
||||
|
|
|
@ -17,10 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.errors.TreeNodeException
|
||||
import org.apache.spark.sql.catalyst.trees.LeafNode
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
object NamedExpression {
|
||||
|
@ -116,6 +116,8 @@ case class Alias(child: Expression, name: String)(
|
|||
|
||||
override def eval(input: Row): Any = child.eval(input)
|
||||
|
||||
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
|
||||
|
||||
override def dataType: DataType = child.dataType
|
||||
override def nullable: Boolean = child.nullable
|
||||
override def metadata: Metadata = {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.trees
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
@ -51,6 +52,25 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
}
|
||||
result
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
""" +
|
||||
children.map { e =>
|
||||
val eval = e.gen(ctx)
|
||||
s"""
|
||||
if (${ev.isNull}) {
|
||||
${eval.code}
|
||||
if (!${eval.isNull}) {
|
||||
${ev.isNull} = false;
|
||||
${ev.primitive} = ${eval.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
}.mkString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
|
||||
|
@ -61,6 +81,13 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
|
|||
child.eval(input) == null
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
ev.isNull = "false"
|
||||
ev.primitive = eval.isNull
|
||||
eval.code
|
||||
}
|
||||
|
||||
override def toString: String = s"IS NULL $child"
|
||||
}
|
||||
|
||||
|
@ -72,6 +99,13 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
|
|||
override def eval(input: Row): Any = {
|
||||
child.eval(input) != null
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval = child.gen(ctx)
|
||||
ev.isNull = "false"
|
||||
ev.primitive = s"(!(${eval.isNull}))"
|
||||
eval.code
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -95,4 +129,25 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
|
|||
}
|
||||
numNonNulls >= n
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val nonnull = ctx.freshName("nonnull")
|
||||
val code = children.map { e =>
|
||||
val eval = e.gen(ctx)
|
||||
s"""
|
||||
if ($nonnull < $n) {
|
||||
${eval.code}
|
||||
if (!${eval.isNull}) {
|
||||
$nonnull += 1;
|
||||
}
|
||||
}
|
||||
"""
|
||||
}.mkString("\n")
|
||||
s"""
|
||||
int $nonnull = 0;
|
||||
$code
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.primitive} = $nonnull >= $n;
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,9 +18,10 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
object InterpretedPredicate {
|
||||
def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
|
||||
|
@ -82,6 +83,10 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex
|
|||
case b: Boolean => !b
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, c => s"!($c)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -141,6 +146,29 @@ case class And(left: Expression, right: Expression)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
|
||||
// The result should be `false`, if any of them is `false` whenever the other is null or not.
|
||||
s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.primitive} = false;
|
||||
|
||||
if (!${eval1.isNull} && !${eval1.primitive}) {
|
||||
} else {
|
||||
${eval2.code}
|
||||
if (!${eval2.isNull} && !${eval2.primitive}) {
|
||||
} else if (!${eval1.isNull} && !${eval2.isNull}) {
|
||||
${ev.primitive} = true;
|
||||
} else {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class Or(left: Expression, right: Expression)
|
||||
|
@ -167,6 +195,29 @@ case class Or(left: Expression, right: Expression)
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
|
||||
// The result should be `true`, if any of them is `true` whenever the other is null or not.
|
||||
s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.primitive} = true;
|
||||
|
||||
if (!${eval1.isNull} && ${eval1.primitive}) {
|
||||
} else {
|
||||
${eval2.code}
|
||||
if (!${eval2.isNull} && ${eval2.primitive}) {
|
||||
} else if (!${eval1.isNull} && !${eval2.isNull}) {
|
||||
${ev.primitive} = false;
|
||||
} else {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
}
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
abstract class BinaryComparison extends BinaryExpression with Predicate {
|
||||
|
@ -198,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
left.dataType match {
|
||||
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
|
||||
(c1, c3) => s"$c1 $symbol $c3"
|
||||
})
|
||||
case TimestampType =>
|
||||
// java.sql.Timestamp does not have compare()
|
||||
super.genCode(ctx, ev)
|
||||
case other => defineCodeGen (ctx, ev, {
|
||||
(c1, c2) => s"$c1.compare($c2) $symbol 0"
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
|
||||
sys.error(s"BinaryComparisons must override either eval or evalInternal")
|
||||
}
|
||||
|
@ -215,6 +280,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
|
|||
if (left.dataType != BinaryType) l == r
|
||||
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
|
||||
}
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
|
||||
}
|
||||
}
|
||||
|
||||
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
|
||||
|
@ -235,6 +303,17 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
|
|||
l == r
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val eval1 = left.gen(ctx)
|
||||
val eval2 = right.gen(ctx)
|
||||
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
|
||||
ev.isNull = "false"
|
||||
eval1.code + eval2.code + s"""
|
||||
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||
|
||||
(!${eval1.isNull} && $equalCode);
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
|
||||
|
@ -309,6 +388,27 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val condEval = predicate.gen(ctx)
|
||||
val trueEval = trueValue.gen(ctx)
|
||||
val falseEval = falseValue.gen(ctx)
|
||||
|
||||
s"""
|
||||
${condEval.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
if (!${condEval.isNull} && ${condEval.primitive}) {
|
||||
${trueEval.code}
|
||||
${ev.isNull} = ${trueEval.isNull};
|
||||
${ev.primitive} = ${trueEval.primitive};
|
||||
} else {
|
||||
${falseEval.code}
|
||||
${ev.isNull} = ${falseEval.isNull};
|
||||
${ev.primitive} = ${falseEval.primitive};
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
||||
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
|
||||
}
|
||||
|
||||
|
@ -393,6 +493,48 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
|
|||
return res
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val len = branchesArr.length
|
||||
val got = ctx.freshName("got")
|
||||
|
||||
val cases = (0 until len/2).map { i =>
|
||||
val cond = branchesArr(i * 2).gen(ctx)
|
||||
val res = branchesArr(i * 2 + 1).gen(ctx)
|
||||
s"""
|
||||
if (!$got) {
|
||||
${cond.code}
|
||||
if (!${cond.isNull} && ${cond.primitive}) {
|
||||
$got = true;
|
||||
${res.code}
|
||||
${ev.isNull} = ${res.isNull};
|
||||
${ev.primitive} = ${res.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
}.mkString("\n")
|
||||
|
||||
val other = if (len % 2 == 1) {
|
||||
val res = branchesArr(len - 1).gen(ctx)
|
||||
s"""
|
||||
if (!$got) {
|
||||
${res.code}
|
||||
${ev.isNull} = ${res.isNull};
|
||||
${ev.primitive} = ${res.primitive};
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
s"""
|
||||
boolean $got = false;
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
$cases
|
||||
$other
|
||||
"""
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
"CASE" + branches.sliding(2, 2).map {
|
||||
case Seq(cond, value) => s" WHEN $cond THEN $value"
|
||||
|
@ -444,6 +586,52 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
|
|||
return res
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val keyEval = key.gen(ctx)
|
||||
val len = branchesArr.length
|
||||
val got = ctx.freshName("got")
|
||||
|
||||
val cases = (0 until len/2).map { i =>
|
||||
val cond = branchesArr(i * 2).gen(ctx)
|
||||
val res = branchesArr(i * 2 + 1).gen(ctx)
|
||||
s"""
|
||||
if (!$got) {
|
||||
${cond.code}
|
||||
if (${keyEval.isNull} && ${cond.isNull} ||
|
||||
!${keyEval.isNull} && !${cond.isNull}
|
||||
&& ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
|
||||
$got = true;
|
||||
${res.code}
|
||||
${ev.isNull} = ${res.isNull};
|
||||
${ev.primitive} = ${res.primitive};
|
||||
}
|
||||
}
|
||||
"""
|
||||
}.mkString("\n")
|
||||
|
||||
val other = if (len % 2 == 1) {
|
||||
val res = branchesArr(len - 1).gen(ctx)
|
||||
s"""
|
||||
if (!$got) {
|
||||
${res.code}
|
||||
${ev.isNull} = ${res.isNull};
|
||||
${ev.primitive} = ${res.primitive};
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
s"""
|
||||
boolean $got = false;
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
|
||||
${keyEval.code}
|
||||
$cases
|
||||
$other
|
||||
"""
|
||||
}
|
||||
|
||||
private def equalNullSafe(l: Any, r: Any) = {
|
||||
if (l == null && r == null) {
|
||||
true
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.collection.OpenHashSet
|
||||
|
||||
|
@ -60,6 +61,17 @@ case class NewSet(elementType: DataType) extends LeafExpression {
|
|||
new OpenHashSet[Any]()
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
elementType match {
|
||||
case IntegerType | LongType =>
|
||||
ev.isNull = "false"
|
||||
s"""
|
||||
${ctx.javaType(dataType)} ${ev.primitive} = new ${ctx.javaType(dataType)}();
|
||||
"""
|
||||
case _ => super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
|
||||
override def toString: String = s"new Set($dataType)"
|
||||
}
|
||||
|
||||
|
@ -91,6 +103,25 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType
|
||||
elementType match {
|
||||
case IntegerType | LongType =>
|
||||
val itemEval = item.gen(ctx)
|
||||
val setEval = set.gen(ctx)
|
||||
val htype = ctx.javaType(dataType)
|
||||
|
||||
ev.isNull = "false"
|
||||
ev.primitive = setEval.primitive
|
||||
itemEval.code + setEval.code + s"""
|
||||
if (!${itemEval.isNull} && !${setEval.isNull}) {
|
||||
(($htype)${setEval.primitive}).add(${itemEval.primitive});
|
||||
}
|
||||
"""
|
||||
case _ => super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
|
||||
override def toString: String = s"$set += $item"
|
||||
}
|
||||
|
||||
|
@ -116,14 +147,31 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
|
|||
val rightValue = iterator.next()
|
||||
leftEval.add(rightValue)
|
||||
}
|
||||
leftEval
|
||||
} else {
|
||||
null
|
||||
}
|
||||
leftEval
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType
|
||||
elementType match {
|
||||
case IntegerType | LongType =>
|
||||
val leftEval = left.gen(ctx)
|
||||
val rightEval = right.gen(ctx)
|
||||
val htype = ctx.javaType(dataType)
|
||||
|
||||
ev.isNull = leftEval.isNull
|
||||
ev.primitive = leftEval.primitive
|
||||
leftEval.code + rightEval.code + s"""
|
||||
if (!${leftEval.isNull} && !${rightEval.isNull}) {
|
||||
${leftEval.primitive}.union((${htype})${rightEval.primitive});
|
||||
}
|
||||
"""
|
||||
case _ => super.genCode(ctx, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import java.util.regex.Pattern
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
trait StringRegexExpression extends ExpectsInputTypes {
|
||||
|
@ -137,6 +138,10 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
|
|||
override def convert(v: UTF8String): UTF8String = v.toUpperCase()
|
||||
|
||||
override def toString: String = s"Upper($child)"
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -147,6 +152,10 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
|
|||
override def convert(v: UTF8String): UTF8String = v.toLowerCase()
|
||||
|
||||
override def toString: String = s"Lower($child)"
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
|
||||
}
|
||||
}
|
||||
|
||||
/** A base trait for functions that compare two strings, returning a boolean. */
|
||||
|
@ -181,6 +190,9 @@ trait StringComparison extends ExpectsInputTypes {
|
|||
case class Contains(left: Expression, right: Expression)
|
||||
extends BinaryExpression with Predicate with StringComparison {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -189,6 +201,9 @@ case class Contains(left: Expression, right: Expression)
|
|||
case class StartsWith(left: Expression, right: Expression)
|
||||
extends BinaryExpression with Predicate with StringComparison {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -197,6 +212,9 @@ case class StartsWith(left: Expression, right: Expression)
|
|||
case class EndsWith(left: Expression, right: Expression)
|
||||
extends BinaryExpression with Predicate with StringComparison {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection}
|
||||
import org.apache.spark.sql.catalyst.expressions.mathfuncs._
|
||||
import org.apache.spark.sql.catalyst.util.DateUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -35,11 +36,20 @@ import org.apache.spark.sql.types._
|
|||
|
||||
class ExpressionEvaluationBaseSuite extends SparkFunSuite {
|
||||
|
||||
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
|
||||
checkEvaluationWithoutCodegen(expression, expected, inputRow)
|
||||
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
|
||||
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
|
||||
}
|
||||
|
||||
def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
|
||||
expression.eval(inputRow)
|
||||
}
|
||||
|
||||
def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
|
||||
def checkEvaluationWithoutCodegen(
|
||||
expression: Expression,
|
||||
expected: Any,
|
||||
inputRow: Row = EmptyRow): Unit = {
|
||||
val actual = try evaluate(expression, inputRow) catch {
|
||||
case e: Exception => fail(s"Exception evaluating $expression", e)
|
||||
}
|
||||
|
@ -49,6 +59,68 @@ class ExpressionEvaluationBaseSuite extends SparkFunSuite {
|
|||
}
|
||||
}
|
||||
|
||||
def checkEvaluationWithGeneratedMutableProjection(
|
||||
expression: Expression,
|
||||
expected: Any,
|
||||
inputRow: Row = EmptyRow): Unit = {
|
||||
|
||||
val plan = try {
|
||||
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
val ctx = GenerateProjection.newCodeGenContext()
|
||||
val evaluated = expression.gen(ctx)
|
||||
fail(
|
||||
s"""
|
||||
|Code generation of $expression failed:
|
||||
|${evaluated.code}
|
||||
|$e
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
val actual = plan(inputRow).apply(0)
|
||||
if (actual != expected) {
|
||||
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
||||
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
|
||||
}
|
||||
}
|
||||
|
||||
def checkEvaluationWithGeneratedProjection(
|
||||
expression: Expression,
|
||||
expected: Any,
|
||||
inputRow: Row = EmptyRow): Unit = {
|
||||
val ctx = GenerateProjection.newCodeGenContext()
|
||||
lazy val evaluated = expression.gen(ctx)
|
||||
|
||||
val plan = try {
|
||||
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
fail(
|
||||
s"""
|
||||
|Code generation of $expression failed:
|
||||
|${evaluated.code}
|
||||
|$e
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
val actual = plan(inputRow)
|
||||
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
|
||||
if (actual.hashCode() != expectedRow.hashCode()) {
|
||||
fail(
|
||||
s"""
|
||||
|Mismatched hashCodes for values: $actual, $expectedRow
|
||||
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
|
||||
|Expressions: ${expression}
|
||||
|Code: ${evaluated}
|
||||
""".stripMargin)
|
||||
}
|
||||
if (actual != expectedRow) {
|
||||
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
||||
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
|
||||
}
|
||||
}
|
||||
|
||||
def checkDoubleEvaluation(
|
||||
expression: Expression,
|
||||
expected: Spread[Double],
|
||||
|
@ -69,8 +141,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
|
|||
test("literals") {
|
||||
checkEvaluation(Literal(1), 1)
|
||||
checkEvaluation(Literal(true), true)
|
||||
checkEvaluation(Literal(false), false)
|
||||
checkEvaluation(Literal(0L), 0L)
|
||||
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
|
||||
d => {
|
||||
checkEvaluation(Literal(d), d)
|
||||
checkEvaluation(Literal(d.toFloat), d.toFloat)
|
||||
}
|
||||
}
|
||||
checkEvaluation(Literal("test"), "test")
|
||||
checkEvaluation(Literal.create(null, StringType), null)
|
||||
checkEvaluation(Literal(1) + Literal(1), 2)
|
||||
}
|
||||
|
||||
|
@ -1367,6 +1447,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
|
|||
// TODO: Make the tests work with codegen.
|
||||
class ExpressionEvaluationWithoutCodeGenSuite extends ExpressionEvaluationBaseSuite {
|
||||
|
||||
override def checkEvaluation(
|
||||
expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
|
||||
checkEvaluationWithoutCodegen(expression, expected, inputRow)
|
||||
}
|
||||
|
||||
test("CreateStruct") {
|
||||
val row = Row(1, 2, 3)
|
||||
val c1 = 'a.int.at(0).as("a")
|
||||
|
|
|
@ -21,34 +21,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
|
||||
/**
|
||||
* Overrides our expression evaluation tests to use code generation for evaluation.
|
||||
* Additional tests for code generation.
|
||||
*/
|
||||
class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
|
||||
override def checkEvaluation(
|
||||
expression: Expression,
|
||||
expected: Any,
|
||||
inputRow: Row = EmptyRow): Unit = {
|
||||
val plan = try {
|
||||
GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
val ctx = GenerateProjection.newCodeGenContext()
|
||||
val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
|
||||
fail(
|
||||
s"""
|
||||
|Code generation of $expression failed:
|
||||
|${evaluated.code}
|
||||
|$e
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
val actual = plan(inputRow).apply(0)
|
||||
if (actual != expected) {
|
||||
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
||||
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
test("multithreaded eval") {
|
||||
import scala.concurrent._
|
||||
|
|
|
@ -1,61 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
|
||||
/**
|
||||
* Overrides our expression evaluation tests to use generated code on mutable rows.
|
||||
*/
|
||||
class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
|
||||
override def checkEvaluation(
|
||||
expression: Expression,
|
||||
expected: Any,
|
||||
inputRow: Row = EmptyRow): Unit = {
|
||||
val ctx = GenerateProjection.newCodeGenContext()
|
||||
lazy val evaluated = GenerateProjection.expressionEvaluator(expression, ctx)
|
||||
|
||||
val plan = try {
|
||||
GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
|
||||
} catch {
|
||||
case e: Throwable =>
|
||||
fail(
|
||||
s"""
|
||||
|Code generation of $expression failed:
|
||||
|${evaluated.code}
|
||||
|$e
|
||||
""".stripMargin)
|
||||
}
|
||||
|
||||
val actual = plan(inputRow)
|
||||
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
|
||||
if (actual.hashCode() != expectedRow.hashCode()) {
|
||||
fail(
|
||||
s"""
|
||||
|Mismatched hashCodes for values: $actual, $expectedRow
|
||||
|Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()}
|
||||
|${evaluated.code}
|
||||
""".stripMargin)
|
||||
}
|
||||
if (actual != expectedRow) {
|
||||
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
|
||||
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -53,7 +53,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
|
|||
|
||||
check("10", Literal.create(10, IntegerType))
|
||||
check("1000000000000000", Literal.create(1000000000000000L, LongType))
|
||||
check("1.5", Literal.create(1.5, FloatType))
|
||||
check("1.5", Literal.create(1.5f, FloatType))
|
||||
check("hello", Literal.create("hello", StringType))
|
||||
check(defaultPartitionName, Literal.create(null, NullType))
|
||||
}
|
||||
|
@ -83,13 +83,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
|
|||
ArrayBuffer(
|
||||
Literal.create(10, IntegerType),
|
||||
Literal.create("hello", StringType),
|
||||
Literal.create(1.5, FloatType)))
|
||||
Literal.create(1.5f, FloatType)))
|
||||
})
|
||||
|
||||
check("file://path/a=10/b_hello/c=1.5", Some {
|
||||
PartitionValues(
|
||||
ArrayBuffer("c"),
|
||||
ArrayBuffer(Literal.create(1.5, FloatType)))
|
||||
ArrayBuffer(Literal.create(1.5f, FloatType)))
|
||||
})
|
||||
|
||||
check("file:///", None)
|
||||
|
|
Loading…
Reference in a new issue