[SPARK-24121][SQL] Add API for handling expression code generation

## What changes were proposed in this pull request?

This patch implements the [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It allows us to manipulate how code is generated for expressions.

In detail, this adds a new abstraction, `CodeBlock`, to `JavaCode`. A `CodeBlock` holds a code snippet together with the inputs needed to generate the actual java code.

For example, in the following java code:

```java
  int ${variable} = 1;
  boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
```

`variable` and `isNull` are two `VariableValue`s, and `CodeGenerator.defaultValue(BooleanType)` is a plain string. The two `VariableValue`s are held by the `CodeBlock` as tracked inputs, while string and primitive literals are folded eagerly into the surrounding code parts, as sketched below.
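
Concretely, following the `foldLiteralArgs` logic added in this patch, the block above roughly decomposes into the sketch below (hedged illustration, not literal output):

```scala
// Hypothetical decomposition: literals such as the folded "false" from
// CodeGenerator.defaultValue(BooleanType) are merged into the code parts,
// while the ExprValues stay tracked as block inputs.
CodeBlock(
  codeParts = Seq("int ", " = 1;\nboolean ", " = false;"),
  blockInputs = Seq(variable, isNull))
```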

For codegen, we provide a dedicated string interpolator, `code`, so you can define a code block like this:
```scala
  val codeBlock =
    code"""
         |int ${variable} = 1;
         |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)};
        """.stripMargin
  // Generates actual java code.
  codeBlock.toString
```
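
Calling `toString` interpolates the tracked inputs back into the code parts, strips the margin, and trims the result. Assuming `variable` and `isNull` resolve to the hypothetical names `value_1` and `isNull_1`, the generated java would be:

```java
int value_1 = 1;
boolean isNull_1 = false;
```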

Because those inputs are held separately by the `CodeBlock` until the actual code is generated, we can safely manipulate them, e.g., replace statements with aliased variables, as sketched below.
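
For instance, a `Block` exposes its tracked inputs through `exprValues` and composes with `+` (both introduced by this patch), so callers can inspect or combine generated code without re-parsing strings. A minimal sketch, reusing `codeBlock` from above:

```scala
// The inputs are available as typed ExprValues rather than raw strings.
val inputs = codeBlock.exprValues  // Set(variable, isNull)

// Concatenation yields a Blocks node that still tracks every input.
val combined = codeBlock + code"return ${variable};"
assert(combined.exprValues == inputs)
```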

## How was this patch tested?

Added tests, including the new `CodeBlockSuite`.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #21193 from viirya/SPARK-24121.
Commit f9f055afa4 (parent 8086acc2f6), authored by Liang-Chi Hsieh on 2018-05-23 01:50:22 +08:00 and committed by Wenchen Fan.
41 changed files with 479 additions and 172 deletions


@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
s"""
code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}


@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
ev.copy(code = eval.code +
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
ev.copy(code =
code"""
${eval.code}
// This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
""")
}
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`


@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
if (eval.code.nonEmpty) {
if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
| ${eval.code.trim}
| ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
eval.value = JavaCode.variable(newValue, dataType)
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {
if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
ev.copy(code = s"""
ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
}
}
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
}
}
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}


@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}
/**
@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
ev.copy(code = s"""
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
/**
@ -1030,7 +1031,7 @@ case class ScalaUDF(
""".stripMargin
ev.copy(code =
s"""
code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}
ev.copy(code = childCode.code +
s"""
code"""
|long ${ev.value} = 0L;
|boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {


@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = FalseLiteral)
}
}


@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
code"""boolean ${ev.isNull} = ${eval.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic {
${ev.value} = $operation;
}""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
}
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
$result
}""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes


@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
object ExprCode {
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
ExprCode(code = "", isNull, value)
ExprCode(code = EmptyBlock, isNull, value)
}
def forNullValue(dataType: DataType): ExprCode = {
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
}
def forNonNullValue(value: ExprValue): ExprCode = {
ExprCode(code = "", isNull = FalseLiteral, value = value)
ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
}
}
@ -330,9 +331,9 @@ class CodegenContext {
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
case StringType => s"$value = $initCode.clone();"
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
case _ => s"$value = $initCode;"
case StringType => code"$value = $initCode.clone();"
case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
case _ => code"$value = $initCode;"
}
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}
@ -1056,7 +1057,7 @@ class CodegenContext {
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(localSubExprEliminationExprs.put(_, state))
eval.code.trim
eval.code.toString
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
@ -1084,7 +1085,7 @@ class CodegenContext {
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
| ${eval.code.trim}
| ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
@ -1141,7 +1142,7 @@ class CodegenContext {
def registerComment(
text: => String,
placeholderId: String = "",
force: Boolean = false): String = {
force: Boolean = false): Block = {
// By default, disable comments in generated code because computing the comments themselves can
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
// inputs with wide schemas. For more details on the performance issues that motivated this
@ -1160,9 +1161,9 @@ class CodegenContext {
s"// $text"
}
placeHolderToComments += (name -> comment)
s"/*$name*/"
code"/*$name*/"
} else {
""
EmptyBlock
}
}
}


@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@ -46,7 +47,7 @@ trait CodegenFallback extends Expression {
val placeHolder = ctx.registerComment(this.toString)
val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
@ -55,7 +56,7 @@ trait CodegenFallback extends Expression {
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;


@ -22,6 +22,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
val code =
s"""
code"""
|final InternalRow $tmpInput = $input;
|final Object[] $values = new Object[${schema.length}];
|$allFields
@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx,
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
elementType)
val code = s"""
val code = code"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
final Object[] $values = new Object[$numElements];
@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
val code = s"""
val code = code"""
final MapData $tmpInput = $input;
${keyConverter.code}
${valueConverter.code}


@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
s"""
code"""
|$rowWriter.reset();
|$evalSubexpr
|$writeExpressions
@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| }
|
| public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
| ${eval.code.trim}
| ${eval.code}
| return ${eval.value};
| }
|


@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import java.lang.{Boolean => JBool}
import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
import org.apache.spark.sql.types.{BooleanType, DataType}
@ -114,6 +115,147 @@ object JavaCode {
}
}
/**
* A trait representing a block of java code.
*/
trait Block extends JavaCode {
// The expressions to be evaluated inside this block.
def exprValues: Set[ExprValue]
// Returns java code string for this code block.
override def toString: String = _marginChar match {
case Some(c) => code.stripMargin(c).trim
case _ => code.trim
}
def length: Int = toString.length
def nonEmpty: Boolean = toString.nonEmpty
// The leading prefix that should be stripped from each line.
// By default we strip blanks or control characters followed by '|' from the line.
var _marginChar: Option[Char] = Some('|')
def stripMargin(c: Char): this.type = {
_marginChar = Some(c)
this
}
def stripMargin: this.type = {
_marginChar = Some('|')
this
}
// Concatenates this block with other block.
def + (other: Block): Block
}
object Block {
val CODE_BLOCK_BUFFER_LENGTH: Int = 512
implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
implicit class BlockHelper(val sc: StringContext) extends AnyVal {
def code(args: Any*): Block = {
sc.checkLengths(args)
if (sc.parts.length == 0) {
EmptyBlock
} else {
args.foreach {
case _: ExprValue =>
case _: Int | _: Long | _: Float | _: Double | _: String =>
case _: Block =>
case other => throw new IllegalArgumentException(
s"Can not interpolate ${other.getClass.getName} into code block.")
}
val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
CodeBlock(codeParts, blockInputs)
}
}
}
// Folds eagerly the literal args into the code parts.
private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = {
val codeParts = ArrayBuffer.empty[String]
val blockInputs = ArrayBuffer.empty[JavaCode]
val strings = parts.iterator
val inputs = args.iterator
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
buf.append(strings.next)
while (strings.hasNext) {
val input = inputs.next
input match {
case _: ExprValue | _: Block =>
codeParts += buf.toString
buf.clear
blockInputs += input.asInstanceOf[JavaCode]
case _ =>
buf.append(input)
}
buf.append(strings.next)
}
if (buf.nonEmpty) {
codeParts += buf.toString
}
(codeParts.toSeq, blockInputs.toSeq)
}
}
/**
* A block of java code. Including a sequence of code parts and some inputs to this block.
* The actual java code is generated by embedding the inputs into the code parts.
*/
case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
override lazy val exprValues: Set[ExprValue] = {
blockInputs.flatMap {
case b: Block => b.exprValues
case e: ExprValue => Set(e)
}.toSet
}
override lazy val code: String = {
val strings = codeParts.iterator
val inputs = blockInputs.iterator
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
buf.append(StringContext.treatEscapes(strings.next))
while (strings.hasNext) {
buf.append(inputs.next)
buf.append(StringContext.treatEscapes(strings.next))
}
buf.toString
}
override def + (other: Block): Block = other match {
case c: CodeBlock => Blocks(Seq(this, c))
case b: Blocks => Blocks(Seq(this) ++ b.blocks)
case EmptyBlock => this
}
}
case class Blocks(blocks: Seq[Block]) extends Block {
override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
override lazy val code: String = blocks.map(_.toString).mkString("\n")
override def + (other: Block): Block = other match {
case c: CodeBlock => Blocks(blocks :+ c)
case b: Blocks => Blocks(blocks ++ b.blocks)
case EmptyBlock => this
}
}
object EmptyBlock extends Block with Serializable {
override val code: String = ""
override val exprValues: Set[ExprValue] = Set.empty
override def + (other: Block): Block = other
}
/**
* A typed java fragment that must be a valid java expression.
*/
@ -123,10 +265,9 @@ trait ExprValue extends JavaCode {
}
object ExprValue {
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code
}
/**
* A java expression fragment.
*/


@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
@ -1177,14 +1178,14 @@ case class ArrayJoin(
}
if (nullable) {
ev.copy(
s"""
code"""
|boolean ${ev.isNull} = true;
|UTF8String ${ev.value} = null;
|$code
""".stripMargin)
} else {
ev.copy(
s"""
code"""
|UTF8String ${ev.value} = null;
|$code
""".stripMargin, FalseLiteral)
@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
s"""
code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
val item = ExprCode("",
val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
s"""
code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil)
ev.copy(s"""
ev.copy(code"""
$initCode
$codes
$javaType ${ev.value} = $concatenator.concat($args);
@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
ev.copy(code =
s"""
code"""
|boolean ${ev.isNull} = false;
|${leftGen.code}
|${rightGen.code}


@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
code = preprocess + assigns + postprocess,
code = code"${preprocess}${assigns}${postprocess}",
value = JavaCode.variable(arrayData, dataType),
isNull = FalseLiteral)
}
@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
val code =
s"""
code"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
$assignKeys
@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
s"""
code"""
|Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val falseEval = falseValue.genCode(ctx)
val code =
s"""
code"""
|${condEval.code}
|boolean ${ev.isNull} = false;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -265,7 +266,7 @@ case class CaseWhen(
}.mkString)
ev.copy(code =
s"""
code"""
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes


@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -717,7 +718,7 @@ abstract class UnixTime
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -746,7 +747,7 @@ abstract class UnixTime
})
case TimestampType =>
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -757,7 +758,7 @@ abstract class UnixTime
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio
val tz = ctx.addReferenceObj("timeZone", timeZone)
val longOpt = ctx.freshName("longOpt")
val eval = child.genCode(ctx)
val code = s"""
val code = code"""
|${eval.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true;
|${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)};
@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
ev.copy(code = s"""
ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
ev.copy(code = s"""
ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
ev.copy(code = s"""
ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};


@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._
/**
@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def eval(input: InternalRow): Any = child.eval(input)
/** Just a simple pass-through for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
ev.copy(EmptyBlock)
override def prettyName: String = "promote_precision"
override def sql: String = child.sql
override lazy val canonicalized: Expression = child.canonicalized


@ -23,6 +23,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ev.copy(code =
s"""
code"""
|$code
|$wrapperClass<InternalRow> ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
""".stripMargin, isNull = FalseLiteral)


@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
s"""
code"""
|$hashResultType ${ev.value} = $seed;
|$codes
""".stripMargin)
@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
s"""
code"""
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getInputFilePath();", isNull = FalseLiteral)
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
isNull = FalseLiteral)
}
}
@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getStartOffset();", isNull = FalseLiteral)
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
}
}
@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getLength();", isNull = FalseLiteral)
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
}
}


@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression,
val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};


@ -21,6 +21,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = s"""${eval.code}
ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = TrueLiteral,
@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
ctx.addPartitionInitializationStatement(s"$randomGen = " +
"new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
s"${randomSeed.get}L + partitionIndex);")
ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
isNull = FalseLiteral)
}


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
val eval = child.genCode(ctx)
child.dataType match {
case DoubleType | FloatType =>
ev.copy(code = s"""
ev.copy(code = code"""
${eval.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression)
val rightGen = right.genCode(ctx)
left.dataType match {
case DoubleType | FloatType =>
ev.copy(code = s"""
ev.copy(code = code"""
${leftGen.code}
boolean ${ev.isNull} = false;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}.mkString)
ev.copy(code =
s"""
code"""
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|do {
| $codes


@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -269,7 +270,7 @@ case class StaticInvoke(
s"${ev.value} = $callFunc;"
}
val code = s"""
val code = code"""
$argCode
$prepareIsNull
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -385,8 +386,7 @@ case class Invoke(
"""
}
val code = s"""
${obj.code}
val code = obj.code + code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${obj.isNull}) {
@ -492,7 +492,7 @@ case class NewInstance(
s"new $className($argString)"
}
val code = s"""
val code = code"""
$argCode
${outer.map(_.code).getOrElse("")}
final $javaType ${ev.value} = ${ev.isNull} ?
@ -532,9 +532,7 @@ case class UnwrapOption(
val javaType = CodeGenerator.javaType(dataType)
val inputObject = child.genCode(ctx)
val code = s"""
${inputObject.code}
val code = inputObject.code + code"""
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
val code = s"""
${inputObject.code}
val code = inputObject.code + code"""
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
@ -935,8 +931,7 @@ case class MapObjects private(
)
}
val code = s"""
${genInputData.code}
val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private(
"""
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
val code = s"""
${genInputData.code}
val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private(
val mapCls = classOf[ArrayBasedMapData].getName
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
val code =
s"""
${inputMap.code}
val code = inputMap.code +
code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${inputMap.isNull}) {
final int $length = ${inputMap.value}.size();
@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val schemaField = ctx.addReferenceObj("schema", schema)
val code =
s"""
code"""
|Object[] $values = new Object[${children.size}];
|$childrenCode
|final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
val javaType = CodeGenerator.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
val code = s"""
${input.code}
val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
"""
@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
val code = s"""
${input.code}
val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
"""
@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
funcName = "initializeJavaBean",
extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
val code =
s"""
|${instanceGen.code}
val code = instanceGen.code +
code"""
|$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
|if (!${instanceGen.isNull}) {
| $initializeCode
@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// because errMsgField is used only when the value is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val code = s"""
${childGen.code}
val code = childGen.code + code"""
if (${childGen.isNull}) {
throw new NullPointerException($errMsgField);
}
@ -1709,7 +1697,7 @@ case class GetExternalRowField(
// because errMsgField is used only when the field is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
val code = s"""
val code = code"""
${row.code}
if (${row.isNull}) {
@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
val code = s"""
val code = code"""
${input.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {


@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}.mkString("\n"))
ev.copy(code =
s"""
code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
""
}
ev.copy(code =
s"""
code"""
|${childGen.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
// The result should be `false`, if any of them is `false` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = false;
@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = false;
@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
// The result should be `true`, if any of them is `true` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
ev.isNull = FalseLiteral
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = true;
@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
ev.copy(code = s"""
ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = true;
@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
ev.copy(code = eval1.code + eval2.code + s"""
ev.copy(code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
(!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
}


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
isNull = FalseLiteral)
}
@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
isNull = FalseLiteral)
}


@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
}
""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
}
""")
} else {
ev.copy(code = s"""
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")


@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression])
expressions = inputs,
funcName = "valueConcatWs",
extraArguments = ("UTF8String[]", args) :: Nil)
ev.copy(s"""
ev.copy(code"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString))
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression])
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
ev.copy(
s"""
code"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxVararg = 0;
@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
}.mkString)
ev.copy(
s"""
code"""
|${index.code}
|final int $indexVal = ${index.value};
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
@ -654,7 +655,7 @@ case class StringTrim(
val srcString = evals(0)
if (evals.length == 1) {
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -671,7 +672,7 @@ case class StringTrim(
} else {
${ev.value} = ${srcString.value}.trim(${trimString.value});
}"""
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -754,7 +755,7 @@ case class StringTrimLeft(
val srcString = evals(0)
if (evals.length == 1) {
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -771,7 +772,7 @@ case class StringTrimLeft(
} else {
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
}"""
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -856,7 +857,7 @@ case class StringTrimRight(
val srcString = evals(0)
if (evals.length == 1) {
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -873,7 +874,7 @@ case class StringTrimRight(
} else {
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
}"""
ev.copy(evals.map(_.code).mkString + s"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
val substrGen = substr.genCode(ctx)
val strGen = str.genCode(ctx)
val startGen = start.genCode(ctx)
ev.copy(code = s"""
ev.copy(code = code"""
int ${ev.value} = 0;
boolean ${ev.isNull} = false;
${startGen.code}
@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
ev.copy(code = s"""
ev.copy(code = code"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};


@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression {
override def eval(input: InternalRow): Any = 10
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.copy(code =
s"""
code"""
|int some_variable = 11;
|int ${ev.value} = 10;
""".stripMargin)

View file

@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{BooleanType, IntegerType}
class CodeBlockSuite extends SparkFunSuite {
test("Block interpolates string and ExprValue inputs") {
val isNull = JavaCode.isNullVariable("expr1_isNull")
val stringLiteral = "false"
val code = code"boolean $isNull = $stringLiteral;"
assert(code.toString == "boolean expr1_isNull = false;")
}
test("Literals are folded into string code parts instead of block inputs") {
val value = JavaCode.variable("expr1", IntegerType)
val intLiteral = 1
val code = code"int $value = $intLiteral;"
assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
}
test("Block.stripMargin") {
val isNull = JavaCode.isNullVariable("expr1_isNull")
val value = JavaCode.variable("expr1", IntegerType)
val code1 =
code"""
|boolean $isNull = false;
|int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin
val expected =
s"""
|boolean expr1_isNull = false;
|int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim
assert(code1.toString == expected)
val code2 =
code"""
>boolean $isNull = false;
>int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>')
assert(code2.toString == expected)
}
test("Block can capture input expr values") {
val isNull = JavaCode.isNullVariable("expr1_isNull")
val value = JavaCode.variable("expr1", IntegerType)
val code =
code"""
|boolean $isNull = false;
|int $value = -1;
""".stripMargin
val exprValues = code.exprValues
assert(exprValues.size == 2)
assert(exprValues === Set(value, isNull))
}
test("concatenate blocks") {
val isNull1 = JavaCode.isNullVariable("expr1_isNull")
val value1 = JavaCode.variable("expr1", IntegerType)
val isNull2 = JavaCode.isNullVariable("expr2_isNull")
val value2 = JavaCode.variable("expr2", IntegerType)
val literal = JavaCode.literal("100", IntegerType)
val code =
code"""
|boolean $isNull1 = false;
|int $value1 = -1;""".stripMargin +
code"""
|boolean $isNull2 = true;
|int $value2 = $literal;""".stripMargin
val expected =
"""
|boolean expr1_isNull = false;
|int expr1 = -1;
|boolean expr2_isNull = true;
|int expr2 = 100;""".stripMargin.trim
assert(code.toString == expected)
val exprValues = code.exprValues
assert(exprValues.size == 5)
assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
}
test("Throws exception when interpolating unexcepted object in code block") {
val obj = Tuple2(1, 1)
val e = intercept[IllegalArgumentException] {
code"$obj"
}
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
}
test("replace expr values in code block") {
val expr = JavaCode.expression("1 + 1", IntegerType)
val isNull = JavaCode.isNullVariable("expr1_isNull")
val exprInFunc = JavaCode.variable("expr1", IntegerType)
val code =
code"""
|callFunc(int $expr) {
| boolean $isNull = false;
| int $exprInFunc = $expr + 1;
|}""".stripMargin
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
case _: SimpleExprValue => aliasedParam
case other => other
}
val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
val expected =
code"""
|callFunc(int $aliasedParam) {
| boolean $isNull = false;
| int $exprInFunc = $aliasedParam + 1;
|}""".stripMargin
assert(aliasedCode.toString == expected.toString)
}
}
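The suite above pins down interpolation, literal folding, margin stripping, `exprValues`, concatenation, and input rewriting. As a hedged illustration of the last test, a hypothetical helper (`renameVariable` is written for this note and is not part of the patch) that swaps one captured input for another while leaving `codeParts` untouched:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.IntegerType

// Hypothetical helper: rebuild a CodeBlock with one input replaced.
def renameVariable(block: Block, from: ExprValue, to: ExprValue): Block = block match {
  case CodeBlock(codeParts, blockInputs) =>
    CodeBlock(codeParts, blockInputs.map {
      case `from` => to
      case other => other
    })
  case other => other
}

val oldVar = JavaCode.variable("expr1", IntegerType)
val newVar = JavaCode.variable("expr1_alias", IntegerType)
val renamed = renameVariable(code"int $oldVar = 42;", oldVar, newVar)
assert(renamed.toString == "int expr1_alias = 42;")
```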

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
}
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
s"""
val code = code"${ctx.registerComment(str)}" + (if (nullable) {
code"""
boolean $isNullVar = $columnVar.isNullAt($ordinal);
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
"""
} else {
s"$javaType $valueVar = $value;"
}).trim
code"$javaType $valueVar = $value;"
})
ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
}
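Note that both branches in the `ColumnarBatchScan` change now produce `Block`s, so the comment prefix and the conditional body compose with `+` and the old `"\n"` / `.trim` bookkeeping disappears; the concatenation test in `CodeBlockSuite` suggests each operand is rendered trimmed on its own line. A sketch of the same shape (names and the `batch` accessor are illustrative, not the generated code of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.IntegerType

val isNullVar = JavaCode.isNullVariable("scan_isNull")
val valueVar = JavaCode.variable("scan_value", IntegerType)
val nullable = true

// Comment and body are both Blocks; `+` takes care of line placement.
val block = code"// read one column value" + (if (nullable) {
  code"""
       |boolean $isNullVar = batch.isNullAt(0);
       |int $valueVar = $isNullVar ? -1 : batch.getInt(0);""".stripMargin
} else {
  code"int $valueVar = batch.getInt(0);"
})
```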

View file

@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
@ -152,7 +153,7 @@ case class ExpandExec(
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val code = s"""
val code = code"""
|boolean $isNull = true;
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
| ${CodeGenerator.defaultValue(firstExpr.dataType)};

View file

@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types._
@ -313,13 +314,13 @@ case class GenerateExec(
if (checks.nonEmpty) {
val isNull = ctx.freshName("isNull")
val code =
s"""
code"""
|boolean $isNull = ${checks.mkString(" || ")};
|$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
""".stripMargin
ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt))
} else {
ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
}
}

View file

@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan {
ctx.INPUT_ROW = row
ctx.currentVars = colVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val code = s"""
val code = code"""
|$evaluateInputs
|${ev.code.trim}
""".stripMargin.trim
|${ev.code}
""".stripMargin
ExprCode(code, FalseLiteral, ev.value)
} else {
// There are no columns
@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan {
* them to be evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
variables.foreach(_.code = EmptyBlock)
evaluate
}
@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan {
val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
evaluateVars.append(ev.code.toString + "\n")
ev.code = EmptyBlock
}
}
evaluateVars.toString()
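The `CodegenSupport` hunks replace the empty-string sentinel with `EmptyBlock`: a consumer checks `nonEmpty`, emits the code once, and resets the field so later consumers skip re-evaluation. A minimal sketch of that lifecycle, assuming `ExprCode.code` is a mutable field (as `foreach(_.code = EmptyBlock)` above implies) and that `EmptyBlock.nonEmpty` is `false`:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.IntegerType

val value = JavaCode.variable("expr1", IntegerType)
val ev = ExprCode(code"int $value = compute();", FalseLiteral, value)

// First consumer emits the code once...
val emitted = if (ev.code.nonEmpty) ev.code.toString else ""
ev.code = EmptyBlock

// ...and later consumers see an empty block and skip re-evaluation.
assert(!ev.code.nonEmpty)
assert(emitted == "int expr1 = compute();")
```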

View file

@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@ -190,7 +191,7 @@ case class HashAggregateExec(
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = s"""
val initVars = code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
@ -773,8 +774,8 @@ case class HashAggregateExec(
val findOrInsertRegularHashMap: String =
s"""
|// generate grouping key
|${unsafeRowKeyCode.code.trim}
|${hashEval.code.trim}
|${unsafeRowKeyCode.code}
|${hashEval.code}
|if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =

View file

@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@ -50,7 +51,7 @@ abstract class HashMapGenerator(
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
val ev = e.genCode(ctx)
val initVars =
s"""
code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin

View file

@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
@ -183,7 +184,7 @@ case class BroadcastHashJoinExec(
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val javaType = CodeGenerator.javaType(a.dataType)
val code = s"""
val code = code"""
|boolean $isNull = true;
|$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|if ($matched != null) {

View file

@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
@ -521,7 +522,7 @@ case class SortMergeJoinExec(
if (a.nullable) {
val isNull = ctx.freshName("isNull")
val code =
s"""
code"""
|$isNull = $leftRow.isNullAt($i);
|$value = $isNull ? $defaultValue : ($valueCode);
""".stripMargin
@ -533,7 +534,7 @@ case class SortMergeJoinExec(
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
leftVarsDecl)
} else {
val code = s"$value = $valueCode;"
val code = code"$value = $valueCode;"
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
}

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructType}
@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator {
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val iteratorClass = classOf[Iterator[_]].getName
ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
ev.copy(code =
code"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
}
}