[SPARK-24121][SQL] Add API for handling expression code generation
## What changes were proposed in this pull request? This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to control how code is generated for expressions. In detail, this adds a new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and inputs for generating actual java code. For example, in the following java code: ```java int ${variable} = 1; boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; ``` `variable` and `isNull` are two `VariableValue` inputs and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and held by the `CodeBlock` representing this code. For codegen, we provide a dedicated string interpolator `code`, so you can define a code block like this: ```scala val codeBlock = code""" |int ${variable} = 1; |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; """.stripMargin // Generates actual java code. codeBlock.toString ``` Because those inputs are held separately in the `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements with aliased variables. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #21193 from viirya/SPARK-24121.
This commit is contained in:
parent
8086acc2f6
commit
f9f055afa4
|
@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.errors.attachTree
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
|
|||
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
|
||||
if (nullable) {
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|
||||
|$javaType ${ev.value} = ${ev.isNull} ?
|
||||
| ${CodeGenerator.defaultValue(dataType)} : ($value);
|
||||
""".stripMargin)
|
||||
} else {
|
||||
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
|
||||
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.SparkException
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val eval = child.genCode(ctx)
|
||||
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
|
||||
ev.copy(code = eval.code +
|
||||
castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
|
||||
|
||||
ev.copy(code =
|
||||
code"""
|
||||
${eval.code}
|
||||
// This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull}
|
||||
${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
|
||||
""")
|
||||
}
|
||||
|
||||
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.Locale
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
JavaCode.isNullVariable(isNull),
|
||||
JavaCode.variable(value, dataType)))
|
||||
reduceCodeSize(ctx, eval)
|
||||
if (eval.code.nonEmpty) {
|
||||
if (eval.code.toString.nonEmpty) {
|
||||
// Add `this` in the comment.
|
||||
eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
|
||||
eval.copy(code = ctx.registerComment(this.toString) + eval.code)
|
||||
} else {
|
||||
eval
|
||||
}
|
||||
|
@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
|
||||
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
|
||||
// TODO: support whole stage codegen too
|
||||
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
|
||||
if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
|
||||
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
|
||||
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
|
||||
val localIsNull = eval.isNull
|
||||
|
@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
val funcFullName = ctx.addNewFunction(funcName,
|
||||
s"""
|
||||
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
|
||||
| ${eval.code.trim}
|
||||
| ${eval.code}
|
||||
| $setIsNull
|
||||
| return ${eval.value};
|
||||
|}
|
||||
""".stripMargin)
|
||||
|
||||
eval.value = JavaCode.variable(newValue, dataType)
|
||||
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
|
||||
eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {
|
|||
|
||||
if (nullable) {
|
||||
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${childGen.code}
|
||||
boolean ${ev.isNull} = ${childGen.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
ev.copy(code = code"""
|
||||
${childGen.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$resultCode""", isNull = FalseLiteral)
|
||||
|
@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
ev.copy(code = code"""
|
||||
${leftGen.code}
|
||||
${rightGen.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
ev.copy(code = code"""
|
||||
${leftGen.code}
|
||||
${midGen.code}
|
||||
${rightGen.code}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.{DataType, LongType}
|
||||
|
||||
/**
|
||||
|
@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
|
|||
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
|
||||
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
|
||||
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
|
||||
$countTerm++;""", isNull = FalseLiteral)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.SparkException
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.DataType
|
||||
|
||||
/**
|
||||
|
@ -1030,7 +1031,7 @@ case class ScalaUDF(
|
|||
""".stripMargin
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|$evalCode
|
||||
|${initArgs.mkString("\n")}
|
||||
|$callFunc
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
|
||||
|
||||
|
@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
|
|||
}
|
||||
|
||||
ev.copy(code = childCode.code +
|
||||
s"""
|
||||
code"""
|
||||
|long ${ev.value} = 0L;
|
||||
|boolean ${ev.isNull} = ${childCode.isNull};
|
||||
|if (!${childCode.isNull}) {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.{DataType, IntegerType}
|
||||
|
||||
/**
|
||||
|
@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
|
|||
val idTerm = "partitionId"
|
||||
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
|
||||
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
|
||||
ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
|
|||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
|
@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
|
|||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val eval = child.genCode(ctx)
|
||||
ev.copy(code = eval.code +
|
||||
s"""boolean ${ev.isNull} = ${eval.isNull};
|
||||
code"""boolean ${ev.isNull} = ${eval.isNull};
|
||||
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
|
||||
""".stripMargin)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic {
|
|||
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
|
||||
}
|
||||
if (!left.nullable && !right.nullable) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic {
|
|||
${ev.value} = $operation;
|
||||
}""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
}
|
||||
|
||||
if (!left.nullable && !right.nullable) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
$result
|
||||
}""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
|
|||
""".stripMargin,
|
||||
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${ev.isNull} = true;
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|$codes
|
||||
|
@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
""".stripMargin,
|
||||
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${ev.isNull} = true;
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|$codes
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
|
|||
import org.apache.spark.metrics.source.CodegenMetrics
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
|
|||
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
|
||||
* valid if `isNull` is set to `true`.
|
||||
*/
|
||||
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
|
||||
case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
|
||||
|
||||
object ExprCode {
|
||||
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
|
||||
ExprCode(code = "", isNull, value)
|
||||
ExprCode(code = EmptyBlock, isNull, value)
|
||||
}
|
||||
|
||||
def forNullValue(dataType: DataType): ExprCode = {
|
||||
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
|
||||
ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
|
||||
}
|
||||
|
||||
def forNonNullValue(value: ExprValue): ExprCode = {
|
||||
ExprCode(code = "", isNull = FalseLiteral, value = value)
|
||||
ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -330,9 +331,9 @@ class CodegenContext {
|
|||
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
|
||||
val value = addMutableState(javaType(dataType), variableName)
|
||||
val code = dataType match {
|
||||
case StringType => s"$value = $initCode.clone();"
|
||||
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
|
||||
case _ => s"$value = $initCode;"
|
||||
case StringType => code"$value = $initCode.clone();"
|
||||
case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
|
||||
case _ => code"$value = $initCode;"
|
||||
}
|
||||
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
|
||||
}
|
||||
|
@ -1056,7 +1057,7 @@ class CodegenContext {
|
|||
val eval = expr.genCode(this)
|
||||
val state = SubExprEliminationState(eval.isNull, eval.value)
|
||||
e.foreach(localSubExprEliminationExprs.put(_, state))
|
||||
eval.code.trim
|
||||
eval.code.toString
|
||||
}
|
||||
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
|
||||
}
|
||||
|
@ -1084,7 +1085,7 @@ class CodegenContext {
|
|||
val fn =
|
||||
s"""
|
||||
|private void $fnName(InternalRow $INPUT_ROW) {
|
||||
| ${eval.code.trim}
|
||||
| ${eval.code}
|
||||
| $isNull = ${eval.isNull};
|
||||
| $value = ${eval.value};
|
||||
|}
|
||||
|
@ -1141,7 +1142,7 @@ class CodegenContext {
|
|||
def registerComment(
|
||||
text: => String,
|
||||
placeholderId: String = "",
|
||||
force: Boolean = false): String = {
|
||||
force: Boolean = false): Block = {
|
||||
// By default, disable comments in generated code because computing the comments themselves can
|
||||
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
|
||||
// inputs with wide schemas. For more details on the performance issues that motivated this
|
||||
|
@ -1160,9 +1161,9 @@ class CodegenContext {
|
|||
s"// $text"
|
||||
}
|
||||
placeHolderToComments += (name -> comment)
|
||||
s"/*$name*/"
|
||||
code"/*$name*/"
|
||||
} else {
|
||||
""
|
||||
EmptyBlock
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
|
||||
/**
|
||||
* A trait that can be used to provide a fallback mode for expression code generation.
|
||||
|
@ -46,7 +47,7 @@ trait CodegenFallback extends Expression {
|
|||
val placeHolder = ctx.registerComment(this.toString)
|
||||
val javaType = CodeGenerator.javaType(this.dataType)
|
||||
if (nullable) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
$placeHolder
|
||||
Object $objectTerm = ((Expression) references[$idx]).eval($input);
|
||||
boolean ${ev.isNull} = $objectTerm == null;
|
||||
|
@ -55,7 +56,7 @@ trait CodegenFallback extends Expression {
|
|||
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
|
||||
}""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
$placeHolder
|
||||
Object $objectTerm = ((Expression) references[$idx]).eval($input);
|
||||
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
|
||||
|
|
|
@ -22,6 +22,7 @@ import scala.annotation.tailrec
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
|
||||
)
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|final InternalRow $tmpInput = $input;
|
||||
|final Object[] $values = new Object[${schema.length}];
|
||||
|$allFields
|
||||
|
@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
ctx,
|
||||
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
|
||||
elementType)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
final ArrayData $tmpInput = $input;
|
||||
final int $numElements = $tmpInput.numElements();
|
||||
final Object[] $values = new Object[$numElements];
|
||||
|
@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
|
||||
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
|
||||
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
final MapData $tmpInput = $input;
|
||||
${keyConverter.code}
|
||||
${valueConverter.code}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
|
||||
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|$rowWriter.reset();
|
||||
|$evalSubexpr
|
||||
|$writeExpressions
|
||||
|
@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
| }
|
||||
|
|
||||
| public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
|
||||
| ${eval.code.trim}
|
||||
| ${eval.code}
|
||||
| return ${eval.value};
|
||||
| }
|
||||
|
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
|
|||
|
||||
import java.lang.{Boolean => JBool}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.language.{existentials, implicitConversions}
|
||||
|
||||
import org.apache.spark.sql.types.{BooleanType, DataType}
|
||||
|
@ -114,6 +115,147 @@ object JavaCode {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait representing a block of java code.
|
||||
*/
|
||||
trait Block extends JavaCode {
|
||||
|
||||
// The expressions to be evaluated inside this block.
|
||||
def exprValues: Set[ExprValue]
|
||||
|
||||
// Returns java code string for this code block.
|
||||
override def toString: String = _marginChar match {
|
||||
case Some(c) => code.stripMargin(c).trim
|
||||
case _ => code.trim
|
||||
}
|
||||
|
||||
def length: Int = toString.length
|
||||
|
||||
def nonEmpty: Boolean = toString.nonEmpty
|
||||
|
||||
// The leading prefix that should be stripped from each line.
|
||||
// By default we strip blanks or control characters followed by '|' from the line.
|
||||
var _marginChar: Option[Char] = Some('|')
|
||||
|
||||
def stripMargin(c: Char): this.type = {
|
||||
_marginChar = Some(c)
|
||||
this
|
||||
}
|
||||
|
||||
def stripMargin: this.type = {
|
||||
_marginChar = Some('|')
|
||||
this
|
||||
}
|
||||
|
||||
// Concatenates this block with other block.
|
||||
def + (other: Block): Block
|
||||
}
|
||||
|
||||
object Block {
|
||||
|
||||
val CODE_BLOCK_BUFFER_LENGTH: Int = 512
|
||||
|
||||
implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
|
||||
|
||||
implicit class BlockHelper(val sc: StringContext) extends AnyVal {
|
||||
def code(args: Any*): Block = {
|
||||
sc.checkLengths(args)
|
||||
if (sc.parts.length == 0) {
|
||||
EmptyBlock
|
||||
} else {
|
||||
args.foreach {
|
||||
case _: ExprValue =>
|
||||
case _: Int | _: Long | _: Float | _: Double | _: String =>
|
||||
case _: Block =>
|
||||
case other => throw new IllegalArgumentException(
|
||||
s"Can not interpolate ${other.getClass.getName} into code block.")
|
||||
}
|
||||
|
||||
val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
|
||||
CodeBlock(codeParts, blockInputs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Folds eagerly the literal args into the code parts.
|
||||
private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = {
|
||||
val codeParts = ArrayBuffer.empty[String]
|
||||
val blockInputs = ArrayBuffer.empty[JavaCode]
|
||||
|
||||
val strings = parts.iterator
|
||||
val inputs = args.iterator
|
||||
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
|
||||
|
||||
buf.append(strings.next)
|
||||
while (strings.hasNext) {
|
||||
val input = inputs.next
|
||||
input match {
|
||||
case _: ExprValue | _: Block =>
|
||||
codeParts += buf.toString
|
||||
buf.clear
|
||||
blockInputs += input.asInstanceOf[JavaCode]
|
||||
case _ =>
|
||||
buf.append(input)
|
||||
}
|
||||
buf.append(strings.next)
|
||||
}
|
||||
if (buf.nonEmpty) {
|
||||
codeParts += buf.toString
|
||||
}
|
||||
|
||||
(codeParts.toSeq, blockInputs.toSeq)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A block of java code. Including a sequence of code parts and some inputs to this block.
|
||||
* The actual java code is generated by embedding the inputs into the code parts.
|
||||
*/
|
||||
case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
|
||||
override lazy val exprValues: Set[ExprValue] = {
|
||||
blockInputs.flatMap {
|
||||
case b: Block => b.exprValues
|
||||
case e: ExprValue => Set(e)
|
||||
}.toSet
|
||||
}
|
||||
|
||||
override lazy val code: String = {
|
||||
val strings = codeParts.iterator
|
||||
val inputs = blockInputs.iterator
|
||||
val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
|
||||
buf.append(StringContext.treatEscapes(strings.next))
|
||||
while (strings.hasNext) {
|
||||
buf.append(inputs.next)
|
||||
buf.append(StringContext.treatEscapes(strings.next))
|
||||
}
|
||||
buf.toString
|
||||
}
|
||||
|
||||
override def + (other: Block): Block = other match {
|
||||
case c: CodeBlock => Blocks(Seq(this, c))
|
||||
case b: Blocks => Blocks(Seq(this) ++ b.blocks)
|
||||
case EmptyBlock => this
|
||||
}
|
||||
}
|
||||
|
||||
case class Blocks(blocks: Seq[Block]) extends Block {
|
||||
override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
|
||||
override lazy val code: String = blocks.map(_.toString).mkString("\n")
|
||||
|
||||
override def + (other: Block): Block = other match {
|
||||
case c: CodeBlock => Blocks(blocks :+ c)
|
||||
case b: Blocks => Blocks(blocks ++ b.blocks)
|
||||
case EmptyBlock => this
|
||||
}
|
||||
}
|
||||
|
||||
object EmptyBlock extends Block with Serializable {
|
||||
override val code: String = ""
|
||||
override val exprValues: Set[ExprValue] = Set.empty
|
||||
|
||||
override def + (other: Block): Block = other
|
||||
}
|
||||
|
||||
/**
|
||||
* A typed java fragment that must be a valid java expression.
|
||||
*/
|
||||
|
@ -123,10 +265,9 @@ trait ExprValue extends JavaCode {
|
|||
}
|
||||
|
||||
object ExprValue {
|
||||
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
|
||||
implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A java expression fragment.
|
||||
*/
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
|
||||
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val childGen = child.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${childGen.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
|
||||
|
@ -1177,14 +1178,14 @@ case class ArrayJoin(
|
|||
}
|
||||
if (nullable) {
|
||||
ev.copy(
|
||||
s"""
|
||||
code"""
|
||||
|boolean ${ev.isNull} = true;
|
||||
|UTF8String ${ev.value} = null;
|
||||
|$code
|
||||
""".stripMargin)
|
||||
} else {
|
||||
ev.copy(
|
||||
s"""
|
||||
code"""
|
||||
|UTF8String ${ev.value} = null;
|
||||
|$code
|
||||
""".stripMargin, FalseLiteral)
|
||||
|
@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
|
|||
val childGen = child.genCode(ctx)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val i = ctx.freshName("i")
|
||||
val item = ExprCode("",
|
||||
val item = ExprCode(EmptyBlock,
|
||||
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
|
||||
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${childGen.code}
|
||||
|boolean ${ev.isNull} = true;
|
||||
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
|
|||
val childGen = child.genCode(ctx)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val i = ctx.freshName("i")
|
||||
val item = ExprCode("",
|
||||
val item = ExprCode(EmptyBlock,
|
||||
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
|
||||
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${childGen.code}
|
||||
|boolean ${ev.isNull} = true;
|
||||
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
|
|||
expressions = inputs,
|
||||
funcName = "valueConcat",
|
||||
extraArguments = (s"$javaType[]", args) :: Nil)
|
||||
ev.copy(s"""
|
||||
ev.copy(code"""
|
||||
$initCode
|
||||
$codes
|
||||
$javaType ${ev.value} = $concatenator.concat($args);
|
||||
|
@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression)
|
|||
val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|boolean ${ev.isNull} = false;
|
||||
|${leftGen.code}
|
||||
|${rightGen.code}
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
|
@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
|
|||
val (preprocess, assigns, postprocess, arrayData) =
|
||||
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
|
||||
ev.copy(
|
||||
code = preprocess + assigns + postprocess,
|
||||
code = code"${preprocess}${assigns}${postprocess}",
|
||||
value = JavaCode.variable(arrayData, dataType),
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
|
@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
|||
val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
|
||||
GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
final boolean ${ev.isNull} = false;
|
||||
$preprocessKeyData
|
||||
$assignKeys
|
||||
|
@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
|
|||
extraArguments = "Object[]" -> values :: Nil)
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|Object[] $values = new Object[${valExprs.size}];
|
||||
|$valuesCode
|
||||
|final InternalRow ${ev.value} = new $rowClass($values);
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
// scalastyle:off line.size.limit
|
||||
|
@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
|
|||
val falseEval = falseValue.genCode(ctx)
|
||||
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|${condEval.code}
|
||||
|boolean ${ev.isNull} = false;
|
||||
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -265,7 +266,7 @@ case class CaseWhen(
|
|||
}.mkString)
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|
||||
|do {
|
||||
| $codes
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -717,7 +718,7 @@ abstract class UnixTime
|
|||
} else {
|
||||
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
|
||||
val eval1 = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -746,7 +747,7 @@ abstract class UnixTime
|
|||
})
|
||||
case TimestampType =>
|
||||
val eval1 = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -757,7 +758,7 @@ abstract class UnixTime
|
|||
val tz = ctx.addReferenceObj("timeZone", timeZone)
|
||||
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
|
||||
val eval1 = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
|
|||
} else {
|
||||
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
|
||||
val t = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${t.code}
|
||||
boolean ${ev.isNull} = ${t.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio
|
|||
val tz = ctx.addReferenceObj("timeZone", timeZone)
|
||||
val longOpt = ctx.freshName("longOpt")
|
||||
val eval = child.genCode(ctx)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
|${eval.code}
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true;
|
||||
|${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)};
|
||||
|
@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
|
|||
if (right.foldable) {
|
||||
val tz = right.eval().asInstanceOf[UTF8String]
|
||||
if (tz == null) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
|boolean ${ev.isNull} = true;
|
||||
|long ${ev.value} = 0;
|
||||
""".stripMargin)
|
||||
|
@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
|
|||
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
|
||||
v => s"""$v = $dtu.getTimeZone("UTC");""")
|
||||
val eval = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
|${eval.code}
|
||||
|boolean ${ev.isNull} = ${eval.isNull};
|
||||
|long ${ev.value} = 0;
|
||||
|
@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
|
|||
if (right.foldable) {
|
||||
val tz = right.eval().asInstanceOf[UTF8String]
|
||||
if (tz == null) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
|boolean ${ev.isNull} = true;
|
||||
|long ${ev.value} = 0;
|
||||
""".stripMargin)
|
||||
|
@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
|
|||
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
|
||||
v => s"""$v = $dtu.getTimeZone("UTC");""")
|
||||
val eval = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
|${eval.code}
|
||||
|boolean ${ev.isNull} = ${eval.isNull};
|
||||
|long ${ev.value} = 0;
|
||||
|
@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
|
|||
val javaType = CodeGenerator.javaType(dataType)
|
||||
if (format.foldable) {
|
||||
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
|
||||
} else {
|
||||
val t = instant.genCode(ctx)
|
||||
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${t.code}
|
||||
boolean ${ev.isNull} = ${t.isNull};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
|
|||
override def eval(input: InternalRow): Any = child.eval(input)
|
||||
/** Just a simple pass-through for code generation. */
|
||||
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
|
||||
ev.copy(EmptyBlock)
|
||||
override def prettyName: String = "promote_precision"
|
||||
override def sql: String = child.sql
|
||||
override lazy val canonicalized: Expression = child.canonicalized
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.sql.Row
|
|||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
|
|||
// Create the collection.
|
||||
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|$code
|
||||
|$wrapperClass<InternalRow> ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
|
||||
""".stripMargin, isNull = FalseLiteral)
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.Platform
|
||||
|
@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression {
|
|||
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|$hashResultType ${ev.value} = $seed;
|
||||
|$codes
|
||||
""".stripMargin)
|
||||
|
@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|
||||
|${CodeGenerator.JAVA_INT} $childHash = 0;
|
||||
|$codes
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.rdd.InputFileBlockHolder
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.{DataType, LongType, StringType}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getInputFilePath();", isNull = FalseLiteral)
|
||||
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
|
||||
ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getStartOffset();", isNull = FalseLiteral)
|
||||
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
|
||||
ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getLength();", isNull = FalseLiteral)
|
||||
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
|
||||
ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.NumberConverter
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression,
|
|||
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
if (scaleV == null) { // if scale is null, no need to eval its child at all
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${ce.code}
|
||||
boolean ${ev.isNull} = ${ce.isNull};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.util.UUID
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
|
|||
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
|
||||
// because errMsgField is used only when the value is null or false.
|
||||
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
|
||||
ExprCode(code = s"""${eval.code}
|
||||
ExprCode(code = code"""${eval.code}
|
||||
|if (${eval.isNull} || !${eval.value}) {
|
||||
| throw new RuntimeException($errMsgField);
|
||||
|}""".stripMargin, isNull = TrueLiteral,
|
||||
|
@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
|
|||
ctx.addPartitionInitializationStatement(s"$randomGen = " +
|
||||
"new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
|
||||
s"${randomSeed.get}L + partitionIndex);")
|
||||
ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
|
||||
ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${ev.isNull} = true;
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|do {
|
||||
|
@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
|
|||
val eval = child.genCode(ctx)
|
||||
child.dataType match {
|
||||
case DoubleType | FloatType =>
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
|
||||
|
@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression)
|
|||
val rightGen = right.genCode(ctx)
|
||||
left.dataType match {
|
||||
case DoubleType | FloatType =>
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${leftGen.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
|
|||
}.mkString)
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|
||||
|do {
|
||||
| $codes
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
|
|||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -269,7 +270,7 @@ case class StaticInvoke(
|
|||
s"${ev.value} = $callFunc;"
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
$argCode
|
||||
$prepareIsNull
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -385,8 +386,7 @@ case class Invoke(
|
|||
"""
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
${obj.code}
|
||||
val code = obj.code + code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${obj.isNull}) {
|
||||
|
@ -492,7 +492,7 @@ case class NewInstance(
|
|||
s"new $className($argString)"
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
$argCode
|
||||
${outer.map(_.code).getOrElse("")}
|
||||
final $javaType ${ev.value} = ${ev.isNull} ?
|
||||
|
@ -532,9 +532,7 @@ case class UnwrapOption(
|
|||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val inputObject = child.genCode(ctx)
|
||||
|
||||
val code = s"""
|
||||
${inputObject.code}
|
||||
|
||||
val code = inputObject.code + code"""
|
||||
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
|
||||
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
|
||||
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
|
||||
|
@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType)
|
|||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val inputObject = child.genCode(ctx)
|
||||
|
||||
val code = s"""
|
||||
${inputObject.code}
|
||||
|
||||
val code = inputObject.code + code"""
|
||||
scala.Option ${ev.value} =
|
||||
${inputObject.isNull} ?
|
||||
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
|
||||
|
@ -935,8 +931,7 @@ case class MapObjects private(
|
|||
)
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
${genInputData.code}
|
||||
val code = genInputData.code + code"""
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
||||
if (!${genInputData.isNull}) {
|
||||
|
@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private(
|
|||
"""
|
||||
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
|
||||
|
||||
val code = s"""
|
||||
${genInputData.code}
|
||||
val code = genInputData.code + code"""
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
||||
if (!${genInputData.isNull}) {
|
||||
|
@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private(
|
|||
val mapCls = classOf[ArrayBasedMapData].getName
|
||||
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
|
||||
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
|
||||
val code =
|
||||
s"""
|
||||
${inputMap.code}
|
||||
val code = inputMap.code +
|
||||
code"""
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${inputMap.isNull}) {
|
||||
final int $length = ${inputMap.value}.size();
|
||||
|
@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
|
|||
val schemaField = ctx.addReferenceObj("schema", schema)
|
||||
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|Object[] $values = new Object[${children.size}];
|
||||
|$childrenCode
|
||||
|final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
|
||||
|
@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
|
|||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val serialize = s"$serializer.serialize(${input.value}, null).array()"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
val code = input.code + code"""
|
||||
final $javaType ${ev.value} =
|
||||
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
|
||||
"""
|
||||
|
@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
|
|||
val deserialize =
|
||||
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
val code = input.code + code"""
|
||||
final $javaType ${ev.value} =
|
||||
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
|
||||
"""
|
||||
|
@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
|
|||
funcName = "initializeJavaBean",
|
||||
extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
|
||||
|
||||
val code =
|
||||
s"""
|
||||
|${instanceGen.code}
|
||||
val code = instanceGen.code +
|
||||
code"""
|
||||
|$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
|
||||
|if (!${instanceGen.isNull}) {
|
||||
| $initializeCode
|
||||
|
@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
|
|||
// because errMsgField is used only when the value is null.
|
||||
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
|
||||
|
||||
val code = s"""
|
||||
${childGen.code}
|
||||
|
||||
val code = childGen.code + code"""
|
||||
if (${childGen.isNull}) {
|
||||
throw new NullPointerException($errMsgField);
|
||||
}
|
||||
|
@ -1709,7 +1697,7 @@ case class GetExternalRowField(
|
|||
// because errMsgField is used only when the field is null.
|
||||
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
|
||||
val row = child.genCode(ctx)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
${row.code}
|
||||
|
||||
if (${row.isNull}) {
|
||||
|
@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
|
|||
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
${input.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${input.isNull}) {
|
||||
|
|
|
@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
|||
}.mkString("\n"))
|
||||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${valueGen.code}
|
||||
|byte $tmpResult = $HAS_NULL;
|
||||
|if (!${valueGen.isNull}) {
|
||||
|
@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
|||
""
|
||||
}
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|${childGen.code}
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|
||||
|
@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
|
|||
|
||||
// The result should be `false`, if any of them is `false` whenever the other is null or not.
|
||||
if (!left.nullable && !right.nullable) {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.value} = false;
|
||||
|
||||
|
@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
|
|||
${ev.value} = ${eval2.value};
|
||||
}""", isNull = FalseLiteral)
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.value} = false;
|
||||
|
@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
|
|||
// The result should be `true`, if any of them is `true` whenever the other is null or not.
|
||||
if (!left.nullable && !right.nullable) {
|
||||
ev.isNull = FalseLiteral
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.value} = true;
|
||||
|
||||
|
@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
|
|||
${ev.value} = ${eval2.value};
|
||||
}""", isNull = FalseLiteral)
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.value} = true;
|
||||
|
@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
|
|||
val eval1 = left.genCode(ctx)
|
||||
val eval2 = right.genCode(ctx)
|
||||
val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
|
||||
ev.copy(code = eval1.code + eval2.code + s"""
|
||||
ev.copy(code = eval1.code + eval2.code + code"""
|
||||
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
|
||||
(!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG {
|
|||
val rngTerm = ctx.addMutableState(className, "rng")
|
||||
ctx.addPartitionInitializationStatement(
|
||||
s"$rngTerm = new $className(${seed}L + partitionIndex);")
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
|
@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG {
|
|||
val rngTerm = ctx.addMutableState(className, "rng")
|
||||
ctx.addPartitionInitializationStatement(
|
||||
s"$rngTerm = new $className(${seed}L + partitionIndex);")
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
|
||||
isNull = FalseLiteral)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern}
|
|||
import org.apache.commons.lang3.StringEscapeUtils
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
|
|||
|
||||
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
|
||||
val eval = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
|
|||
}
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
""")
|
||||
|
@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
|
|||
|
||||
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
|
||||
val eval = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
|
|||
}
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
""")
|
||||
|
|
|
@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
expressions = inputs,
|
||||
funcName = "valueConcatWs",
|
||||
extraArguments = ("UTF8String[]", args) :: Nil)
|
||||
ev.copy(s"""
|
||||
ev.copy(code"""
|
||||
UTF8String[] $args = new UTF8String[$numArgs];
|
||||
${separator.code}
|
||||
$codes
|
||||
|
@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
}
|
||||
}.unzip
|
||||
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString))
|
||||
|
||||
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = varargCount,
|
||||
|
@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
|
||||
|
||||
ev.copy(
|
||||
s"""
|
||||
code"""
|
||||
$codes
|
||||
int $varargNum = ${children.count(_.dataType == StringType) - 1};
|
||||
int $idxVararg = 0;
|
||||
|
@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|
|||
}.mkString)
|
||||
|
||||
ev.copy(
|
||||
s"""
|
||||
code"""
|
||||
|${index.code}
|
||||
|final int $indexVal = ${index.value};
|
||||
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|
||||
|
@ -654,7 +655,7 @@ case class StringTrim(
|
|||
val srcString = evals(0)
|
||||
|
||||
if (evals.length == 1) {
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -671,7 +672,7 @@ case class StringTrim(
|
|||
} else {
|
||||
${ev.value} = ${srcString.value}.trim(${trimString.value});
|
||||
}"""
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -754,7 +755,7 @@ case class StringTrimLeft(
|
|||
val srcString = evals(0)
|
||||
|
||||
if (evals.length == 1) {
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -771,7 +772,7 @@ case class StringTrimLeft(
|
|||
} else {
|
||||
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
|
||||
}"""
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -856,7 +857,7 @@ case class StringTrimRight(
|
|||
val srcString = evals(0)
|
||||
|
||||
if (evals.length == 1) {
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -873,7 +874,7 @@ case class StringTrimRight(
|
|||
} else {
|
||||
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
|
||||
}"""
|
||||
ev.copy(evals.map(_.code).mkString + s"""
|
||||
ev.copy(evals.map(_.code) :+ code"""
|
||||
boolean ${ev.isNull} = false;
|
||||
UTF8String ${ev.value} = null;
|
||||
if (${srcString.isNull}) {
|
||||
|
@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
|
|||
val substrGen = substr.genCode(ctx)
|
||||
val strGen = str.genCode(ctx)
|
||||
val startGen = start.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
int ${ev.value} = 0;
|
||||
boolean ${ev.isNull} = false;
|
||||
${startGen.code}
|
||||
|
@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
|
|||
val formatter = classOf[java.util.Formatter].getName
|
||||
val sb = ctx.freshName("sb")
|
||||
val stringBuffer = classOf[StringBuffer].getName
|
||||
ev.copy(code = s"""
|
||||
ev.copy(code = code"""
|
||||
${pattern.code}
|
||||
boolean ${ev.isNull} = ${pattern.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.{DataType, IntegerType}
|
||||
|
||||
/**
|
||||
|
@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression {
|
|||
override def eval(input: InternalRow): Any = 10
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
ev.copy(code =
|
||||
s"""
|
||||
code"""
|
||||
|int some_variable = 11;
|
||||
|int ${ev.value} = 10;
|
||||
""".stripMargin)
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types.{BooleanType, IntegerType}
|
||||
|
||||
class CodeBlockSuite extends SparkFunSuite {
|
||||
|
||||
test("Block interpolates string and ExprValue inputs") {
|
||||
val isNull = JavaCode.isNullVariable("expr1_isNull")
|
||||
val stringLiteral = "false"
|
||||
val code = code"boolean $isNull = $stringLiteral;"
|
||||
assert(code.toString == "boolean expr1_isNull = false;")
|
||||
}
|
||||
|
||||
test("Literals are folded into string code parts instead of block inputs") {
|
||||
val value = JavaCode.variable("expr1", IntegerType)
|
||||
val intLiteral = 1
|
||||
val code = code"int $value = $intLiteral;"
|
||||
assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
|
||||
}
|
||||
|
||||
test("Block.stripMargin") {
|
||||
val isNull = JavaCode.isNullVariable("expr1_isNull")
|
||||
val value = JavaCode.variable("expr1", IntegerType)
|
||||
val code1 =
|
||||
code"""
|
||||
|boolean $isNull = false;
|
||||
|int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin
|
||||
val expected =
|
||||
s"""
|
||||
|boolean expr1_isNull = false;
|
||||
|int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim
|
||||
assert(code1.toString == expected)
|
||||
|
||||
val code2 =
|
||||
code"""
|
||||
>boolean $isNull = false;
|
||||
>int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>')
|
||||
assert(code2.toString == expected)
|
||||
}
|
||||
|
||||
test("Block can capture input expr values") {
|
||||
val isNull = JavaCode.isNullVariable("expr1_isNull")
|
||||
val value = JavaCode.variable("expr1", IntegerType)
|
||||
val code =
|
||||
code"""
|
||||
|boolean $isNull = false;
|
||||
|int $value = -1;
|
||||
""".stripMargin
|
||||
val exprValues = code.exprValues
|
||||
assert(exprValues.size == 2)
|
||||
assert(exprValues === Set(value, isNull))
|
||||
}
|
||||
|
||||
test("concatenate blocks") {
|
||||
val isNull1 = JavaCode.isNullVariable("expr1_isNull")
|
||||
val value1 = JavaCode.variable("expr1", IntegerType)
|
||||
val isNull2 = JavaCode.isNullVariable("expr2_isNull")
|
||||
val value2 = JavaCode.variable("expr2", IntegerType)
|
||||
val literal = JavaCode.literal("100", IntegerType)
|
||||
|
||||
val code =
|
||||
code"""
|
||||
|boolean $isNull1 = false;
|
||||
|int $value1 = -1;""".stripMargin +
|
||||
code"""
|
||||
|boolean $isNull2 = true;
|
||||
|int $value2 = $literal;""".stripMargin
|
||||
|
||||
val expected =
|
||||
"""
|
||||
|boolean expr1_isNull = false;
|
||||
|int expr1 = -1;
|
||||
|boolean expr2_isNull = true;
|
||||
|int expr2 = 100;""".stripMargin.trim
|
||||
|
||||
assert(code.toString == expected)
|
||||
|
||||
val exprValues = code.exprValues
|
||||
assert(exprValues.size == 5)
|
||||
assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
|
||||
}
|
||||
|
||||
test("Throws exception when interpolating unexcepted object in code block") {
|
||||
val obj = Tuple2(1, 1)
|
||||
val e = intercept[IllegalArgumentException] {
|
||||
code"$obj"
|
||||
}
|
||||
assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
|
||||
}
|
||||
|
||||
test("replace expr values in code block") {
|
||||
val expr = JavaCode.expression("1 + 1", IntegerType)
|
||||
val isNull = JavaCode.isNullVariable("expr1_isNull")
|
||||
val exprInFunc = JavaCode.variable("expr1", IntegerType)
|
||||
|
||||
val code =
|
||||
code"""
|
||||
|callFunc(int $expr) {
|
||||
| boolean $isNull = false;
|
||||
| int $exprInFunc = $expr + 1;
|
||||
|}""".stripMargin
|
||||
|
||||
val aliasedParam = JavaCode.variable("aliased", expr.javaType)
|
||||
val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
|
||||
case _: SimpleExprValue => aliasedParam
|
||||
case other => other
|
||||
}
|
||||
val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
|
||||
val expected =
|
||||
code"""
|
||||
|callFunc(int $aliasedParam) {
|
||||
| boolean $isNull = false;
|
||||
| int $exprInFunc = $aliasedParam + 1;
|
||||
|}""".stripMargin
|
||||
assert(aliasedCode.toString == expected.toString)
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
|
||||
|
@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|
|||
}
|
||||
val valueVar = ctx.freshName("value")
|
||||
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
|
||||
val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
|
||||
s"""
|
||||
val code = code"${ctx.registerComment(str)}" + (if (nullable) {
|
||||
code"""
|
||||
boolean $isNullVar = $columnVar.isNullAt($ordinal);
|
||||
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
|
||||
"""
|
||||
} else {
|
||||
s"$javaType $valueVar = $value;"
|
||||
}).trim
|
||||
code"$javaType $valueVar = $value;"
|
||||
})
|
||||
ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
|
@ -152,7 +153,7 @@ case class ExpandExec(
|
|||
} else {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val value = ctx.freshName("value")
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
|boolean $isNull = true;
|
||||
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
|
||||
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
|
||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -313,13 +314,13 @@ case class GenerateExec(
|
|||
if (checks.nonEmpty) {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|boolean $isNull = ${checks.mkString(" || ")};
|
||||
|$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
|
||||
""".stripMargin
|
||||
ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt))
|
||||
} else {
|
||||
ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
|
||||
ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
|
||||
|
@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan {
|
|||
ctx.INPUT_ROW = row
|
||||
ctx.currentVars = colVars
|
||||
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
|$evaluateInputs
|
||||
|${ev.code.trim}
|
||||
""".stripMargin.trim
|
||||
|${ev.code}
|
||||
""".stripMargin
|
||||
ExprCode(code, FalseLiteral, ev.value)
|
||||
} else {
|
||||
// There are no columns
|
||||
|
@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan {
|
|||
* them to be evaluated twice.
|
||||
*/
|
||||
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
|
||||
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
|
||||
variables.foreach(_.code = "")
|
||||
val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
|
||||
variables.foreach(_.code = EmptyBlock)
|
||||
evaluate
|
||||
}
|
||||
|
||||
|
@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan {
|
|||
val evaluateVars = new StringBuilder
|
||||
variables.zipWithIndex.foreach { case (ev, i) =>
|
||||
if (ev.code != "" && required.contains(attributes(i))) {
|
||||
evaluateVars.append(ev.code.trim + "\n")
|
||||
ev.code = ""
|
||||
evaluateVars.append(ev.code.toString + "\n")
|
||||
ev.code = EmptyBlock
|
||||
}
|
||||
}
|
||||
evaluateVars.toString()
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
|
||||
|
@ -190,7 +191,7 @@ case class HashAggregateExec(
|
|||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
// The initial expression should not access any column
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars = s"""
|
||||
val initVars = code"""
|
||||
| $isNull = ${ev.isNull};
|
||||
| $value = ${ev.value};
|
||||
""".stripMargin
|
||||
|
@ -773,8 +774,8 @@ case class HashAggregateExec(
|
|||
val findOrInsertRegularHashMap: String =
|
||||
s"""
|
||||
|// generate grouping key
|
||||
|${unsafeRowKeyCode.code.trim}
|
||||
|${hashEval.code.trim}
|
||||
|${unsafeRowKeyCode.code}
|
||||
|${hashEval.code}
|
||||
|if ($checkFallbackForBytesToBytesMap) {
|
||||
| // try to get the buffer from hash map
|
||||
| $unsafeRowBuffer =
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
|
|||
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -50,7 +51,7 @@ abstract class HashMapGenerator(
|
|||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars =
|
||||
s"""
|
||||
code"""
|
||||
| $isNull = ${ev.isNull};
|
||||
| $value = ${ev.value};
|
||||
""".stripMargin
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
|
||||
|
@ -183,7 +184,7 @@ case class BroadcastHashJoinExec(
|
|||
val isNull = ctx.freshName("isNull")
|
||||
val value = ctx.freshName("value")
|
||||
val javaType = CodeGenerator.javaType(a.dataType)
|
||||
val code = s"""
|
||||
val code = code"""
|
||||
|boolean $isNull = true;
|
||||
|$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|
||||
|if ($matched != null) {
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution._
|
||||
|
@ -521,7 +522,7 @@ case class SortMergeJoinExec(
|
|||
if (a.nullable) {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val code =
|
||||
s"""
|
||||
code"""
|
||||
|$isNull = $leftRow.isNullAt($i);
|
||||
|$value = $isNull ? $defaultValue : ($valueCode);
|
||||
""".stripMargin
|
||||
|
@ -533,7 +534,7 @@ case class SortMergeJoinExec(
|
|||
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
|
||||
leftVarsDecl)
|
||||
} else {
|
||||
val code = s"$value = $valueCode;"
|
||||
val code = code"$value = $valueCode;"
|
||||
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
|
||||
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||
|
@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator {
|
|||
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val iteratorClass = classOf[Iterator[_]].getName
|
||||
ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
|
||||
ev.copy(code =
|
||||
code"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue