[SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext

## What changes were proposed in this pull request?

The current `CodegenContext` class also contains values and methods that do not depend on any mutable state.
This refactoring moves them to the `CodeGenerator` object, which can be accessed from anywhere in the program without instantiating a `CodegenContext`.
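
The pattern, as a minimal sketch rather than the actual Spark code (class and object bodies are heavily abbreviated): stateful services such as `freshName` stay on the `CodegenContext` instance, while stateless helpers such as `javaType`, `defaultValue`, and the `JAVA_*` constants become members of the `CodeGenerator` object.

```scala
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Simplified stand-ins for the real classes; signatures follow the patch,
// bodies are abbreviated.
class CodegenContext {
  // Mutable state stays on the instance, e.g. the fresh-name counters.
  private val freshNameIds = scala.collection.mutable.HashMap.empty[String, Int]
  def freshName(name: String): String = {
    val id = freshNameIds.getOrElse(name, 0)
    freshNameIds(name) = id + 1
    s"${name}_$id"
  }
}

// Stateless helpers move here and are reachable without an instance,
// e.g. CodeGenerator.javaType(dataType).
object CodeGenerator {
  final val JAVA_INT = "int"
  final val JAVA_LONG = "long"

  def javaType(dt: DataType): String = dt match {
    case IntegerType => JAVA_INT
    case LongType    => JAVA_LONG
    case _           => "Object"
  }

  def defaultValue(dt: DataType, typedNull: Boolean = false): String = javaType(dt) match {
    case JAVA_INT  => "-1"
    case JAVA_LONG => "-1L"
    case jt        => if (typedNull) s"(($jt)null)" else "null"
  }
}
```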

## How was this patch tested?

Existing tests

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #20700 from kiszk/SPARK-23546.
Kazuaki Ishizaki authored on 2018-03-05 11:39:01 +01:00; committed by Herman van Hovell
parent 269cd53590
commit 2ce37b50fc
45 changed files with 535 additions and 497 deletions


@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
/**
@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
ev.copy(code = oev.code)
} else {
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
val javaType = CodeGenerator.javaType(dataType)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
s"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")


@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
s"""
boolean $resultIsNull = $inputIsNull;
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)};
if (!$inputIsNull) {
${cast(input, result, resultIsNull)}
}
@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(et)} element) {
|private UTF8String $funcName(${CodeGenerator.javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|$buffer.append("[");
|if ($array.numElements() > 0) {
| if (!$array.isNullAt(0)) {
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
| }
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
| $buffer.append(",");
| if (!$array.isNullAt($loopIndex)) {
| $buffer.append(" ");
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
| }
| }
|}
@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val dataToStringCode = castToStringCode(dataType, ctx)
ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(dataType)} data) {
|private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) {
| UTF8String dataStr = null;
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
| return dataStr;
@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val keyToStringFunc = dataToStringFunc("keyToString", kt)
val valueToStringFunc = dataToStringFunc("valueToString", vt)
val loopIndex = ctx.freshName("loopIndex")
val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0")
val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0")
val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex)
val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex)
s"""
|$buffer.append("[");
|if ($map.numElements() > 0) {
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
| $buffer.append($keyToStringFunc($getMapFirstKey));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt(0)) {
| $buffer.append(" ");
| $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
| $buffer.append($valueToStringFunc($getMapFirstValue));
| }
| for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
| $buffer.append(", ");
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
| $buffer.append($keyToStringFunc($getMapKeyArray));
| $buffer.append(" ->");
| if (!$map.valueArray().isNullAt($loopIndex)) {
| $buffer.append(" ");
| $buffer.append($valueToStringFunc(
| ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
| $buffer.append($valueToStringFunc($getMapValueArray));
| }
| }
|}
@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
| // Append $i field into the string buffer
| ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
| ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")};
| UTF8String $fieldStr = null;
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
| $buffer.append($fieldStr);
@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
$values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(fromType)} $fromElementPrim =
${ctx.getValue(c, fromType, j)};
${CodeGenerator.javaType(fromType)} $fromElementPrim =
${CodeGenerator.getValue(c, fromType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
if ($toElementNull) {
@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
val toFieldNull = ctx.freshName("tfn")
val fromType = ctx.javaType(from.fields(i).dataType)
val fromType = CodeGenerator.javaType(from.fields(i).dataType)
s"""
boolean $fromFieldNull = $tmpInput.isNullAt($i);
if ($fromFieldNull) {
$tmpResult.setNullAt($i);
} else {
$fromType $fromFieldPrim =
${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
$tmpResult.setNullAt($i);
} else {
${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
}
}
"""


@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] {
// TODO: support whole stage codegen too
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
eval.isNull = globalIsNull
s"$globalIsNull = $localIsNull;"
@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] {
""
}
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val newValue = ctx.freshName("value")
val funcName = ctx.freshName(nodeName)
@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression {
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression {
boolean ${ev.isNull} = false;
${leftGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = s"""
@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression {
${leftGen.code}
${midGen.code}
${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}


@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType}
/**
@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
val partitionMaskTerm = "partitionMask"
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = "false")
}


@ -1018,11 +1018,12 @@ case class ScalaUDF(
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
val resultConverter = s"$convertersTerm[${children.length}]"
val boxedType = CodeGenerator.boxedType(dataType)
val callFunc =
s"""
|${ctx.boxedType(dataType)} $resultTerm = null;
|$boxedType $resultTerm = null;
|try {
| $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
| $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|}
@ -1035,7 +1036,7 @@ case class ScalaUDF(
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}


@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = "false")
}
}


@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@ -165,7 +165,7 @@ case class PreciseTimestampConversion(
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
s"""boolean ${ev.isNull} = ${eval.isNull};
|${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
override def nullSafeEval(input: Any): Any = input


@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
// codegen would fail to compile if we just write (-($c))
// for example, we could not write --9223372036854775808L in code
s"""
${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
${ev.value} = (${ctx.javaType(dt)})(-($originValue));
${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}
@ -107,7 +107,7 @@ case class Abs(child: Expression)
case dt: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
case _ =>
@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
case _ =>
@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"${eval2.value} == 0"
}
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val divide = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else {
s"${eval2.value} == 0"
}
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val remainder = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
s"${eval2.value} == 0"
}
val remainder = ctx.freshName("remainder")
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val result = dataType match {
case DecimalType.Fixed(_, _) =>
val decimalAdd = "$plus"
s"""
${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value});
$javaType $remainder = ${eval1.value}.remainder(${eval2.value});
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value});
} else {
@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
s"""
${ctx.javaType(dataType)} $remainder =
(${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value});
$javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
if ($remainder < 0) {
${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value});
${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
} else {
${ev.value}=$remainder;
}
"""
case _ =>
s"""
${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value};
$javaType $remainder = ${eval1.value} % ${eval2.value};
if ($remainder < 0) {
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
} else {
@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if ($isZero) {
${ev.isNull} = true;
} else {
@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
ev.copy(code = s"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin
)
val resultType = ctx.javaType(dataType)
val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "least",
@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|${ev.isNull} = true;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}
@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evalChildren = children.map(_.genCode(ctx))
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
val evals = evalChildren.map(eval =>
s"""
|${eval.code}
@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin
)
val resultType = ctx.javaType(dataType)
val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "greatest",
@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|${ev.isNull} = true;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
""".stripMargin)
}


@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)")
}
protected override def nullSafeEval(input: Any): Any = not(input)


@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
case class ExprCode(var code: String, var isNull: String, var value: String)
object ExprCode {
def forNullValue(dataType: DataType): ExprCode = {
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
}
def forNonNullValue(value: String): ExprCode = {
ExprCode(code = "", isNull = "false", value = value)
}
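
As a hedged illustration (not part of this diff), the new factory pairs `isNull = "true"` with the typed default from `CodeGenerator.defaultValue(dataType, typedNull = true)`, so assignments in the generated Java still type-check without emitting any evaluation code:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
import org.apache.spark.sql.types.{LongType, StringType}

// Non-primitive types get a typed null, keeping the cast target explicit.
val nullStr = ExprCode.forNullValue(StringType)
// nullStr.code == "", nullStr.isNull == "true", nullStr.value == "((UTF8String)null)"

// Primitive types fall back to their sentinel default value.
val nullLong = ExprCode.forNullValue(LongType)
// nullLong.value == "-1L"
```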
@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec(
*/
class CodegenContext {
import CodeGenerator._
/**
* Holding a list of objects that could be used passed into generated class.
*/
@ -196,11 +203,11 @@ class CodegenContext {
/**
* Returns the reference of next available slot in current compacted array. The size of each
* compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`.
* Once reaching the threshold, new compacted array is created.
*/
def getNextSlot(): String = {
if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) {
val res = s"${arrayNames.last}[$currentIndex]"
currentIndex += 1
res
@ -247,10 +254,10 @@ class CodegenContext {
* are satisfied:
* 1. forceInline is true
* 2. its type is primitive type and the total number of the inlined mutable variables
* is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
* is less than `OUTER_CLASS_VARIABLES_THRESHOLD`
* 3. its type is multi-dimensional array
* When a variable is compacted into an array, the max size of the array for compaction
* is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
* is given by `MUTABLESTATEARRAY_SIZE_LIMIT`.
*/
def addMutableState(
javaType: String,
@ -261,7 +268,7 @@ class CodegenContext {
// want to put a primitive type variable at outerClass for performance
val canInlinePrimitive = isPrimitiveType(javaType) &&
(inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
(inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD)
if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
val varName = if (useFreshName) freshName(variableName) else variableName
val initCode = initFunc(varName)
@ -339,7 +346,7 @@ class CodegenContext {
val length = if (index + 1 == numArrays) {
mutableStateArrays.getCurrentIndex
} else {
CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
MUTABLESTATEARRAY_SIZE_LIMIT
}
if (javaType.contains("[]")) {
// initializer had an one-dimensional array variable
@ -468,7 +475,7 @@ class CodegenContext {
inlineToOuterClass: Boolean): NewFunctionSpec = {
val (className, classInstance) = if (inlineToOuterClass) {
outerClassName -> ""
} else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) {
} else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) {
val className = freshName("NestedClass")
val classInstance = freshName("nestedClassInstance")
@ -537,14 +544,6 @@ class CodegenContext {
extraClasses.append(code)
}
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
final val JAVA_INT = "int"
final val JAVA_LONG = "long"
final val JAVA_FLOAT = "float"
final val JAVA_DOUBLE = "double"
/**
* The map from a variable name to it's next ID.
*/
@ -580,196 +579,6 @@ class CodegenContext {
}
}
/**
* Returns the specialized code to access a value from `inputRow` at `ordinal`.
*/
def getValue(input: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
case StringType => s"$input.getUTF8String($ordinal)"
case BinaryType => s"$input.getBinary($ordinal)"
case CalendarIntervalType => s"$input.getInterval($ordinal)"
case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
case _: ArrayType => s"$input.getArray($ordinal)"
case _: MapType => s"$input.getMap($ordinal)"
case NullType => "null"
case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
case _ => s"($jt)$input.get($ordinal, null)"
}
}
/**
* Returns the code to update a column in Row for a given DataType.
*/
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
// The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
// it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
case StringType | _: StructType | _: ArrayType | _: MapType =>
s"$row.update($ordinal, $value.copy())"
case _ => s"$row.update($ordinal, $value)"
}
}
/**
* Update a column in MutableRow from ExprCode.
*
* @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
*/
def updateColumn(
row: String,
dataType: DataType,
ordinal: Int,
ev: ExprCode,
nullable: Boolean,
isVectorized: Boolean = false): String = {
if (nullable) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
s"""
if (!${ev.isNull}) {
${setColumn(row, dataType, ordinal, ev.value)};
} else {
${setColumn(row, dataType, ordinal, "null")};
}
"""
} else {
s"""
if (!${ev.isNull}) {
${setColumn(row, dataType, ordinal, ev.value)};
} else {
$row.setNullAt($ordinal);
}
"""
}
} else {
s"""${setColumn(row, dataType, ordinal, ev.value)};"""
}
}
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`.
*/
def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
}
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`
* that could potentially be nullable.
*/
def updateColumn(
vector: String,
rowId: String,
dataType: DataType,
ev: ExprCode,
nullable: Boolean): String = {
if (nullable) {
s"""
if (!${ev.isNull}) {
${setValue(vector, rowId, dataType, ev.value)}
} else {
$vector.putNull($rowId);
}
"""
} else {
s"""${setValue(vector, rowId, dataType, ev.value)};"""
}
}
/**
* Returns the specialized code to access a value from a column vector for a given `DataType`.
*/
def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
if (dataType.isInstanceOf[StructType]) {
// `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
// `ordinal` parameter.
s"$vector.getStruct($rowId)"
} else {
getValue(vector, dataType, rowId)
}
}
/**
* Returns the name used in accessor and setter for a Java primitive type.
*/
def primitiveTypeName(jt: String): String = jt match {
case JAVA_INT => "Int"
case _ => boxedType(jt)
}
def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
/**
* Returns the Java type for a DataType.
*/
def javaType(dt: DataType): String = dt match {
case BooleanType => JAVA_BOOLEAN
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case LongType | TimestampType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
case _: ArrayType => "ArrayData"
case _: MapType => "MapData"
case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
case _ => "Object"
}
/**
* Returns the boxed type in Java.
*/
def boxedType(jt: String): String = jt match {
case JAVA_BOOLEAN => "Boolean"
case JAVA_BYTE => "Byte"
case JAVA_SHORT => "Short"
case JAVA_INT => "Integer"
case JAVA_LONG => "Long"
case JAVA_FLOAT => "Float"
case JAVA_DOUBLE => "Double"
case other => other
}
def boxedType(dt: DataType): String = boxedType(javaType(dt))
/**
* Returns the representation of default value for a given Java Type.
*/
def defaultValue(jt: String): String = jt match {
case JAVA_BOOLEAN => "false"
case JAVA_BYTE => "(byte)-1"
case JAVA_SHORT => "(short)-1"
case JAVA_INT => "-1"
case JAVA_LONG => "-1L"
case JAVA_FLOAT => "-1.0f"
case JAVA_DOUBLE => "-1.0"
case _ => "null"
}
def defaultValue(dt: DataType): String = defaultValue(javaType(dt))
/**
* Generates code for equal expression in Java.
*/
@ -812,6 +621,7 @@ class CodegenContext {
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
@ -833,8 +643,8 @@ class CodegenContext {
} else if ($isNullB) {
return 1;
} else {
${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
$jt $elementA = ${getValue("a", elementType, "i")};
$jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
@ -906,19 +716,6 @@ class CodegenContext {
}
}
/**
* List of java data types that have special accessors and setters in [[InternalRow]].
*/
val primitiveTypes =
Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
/**
* Returns true if the Java type has a special accessor and setter in [[InternalRow]].
*/
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
/**
* Splits the generated code of expressions into multiple functions, because function has
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
@ -1089,7 +886,7 @@ class CodegenContext {
// for performance reasons, the functions are prepended, instead of appended,
// thus here they are in reversed order
val orderedFunctions = innerClassFunctions.reverse
if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) {
// Adding a new function to each inner class which contains the invocation of all the
// ones which have been added to that inner class. For example,
// private class NestedClass {
@ -1289,7 +1086,7 @@ class CodegenContext {
* length less than a pre-defined constant.
*/
def isValidParamLength(paramLength: Int): Boolean = {
paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
}
}
@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging {
result
}
})
/**
* Name of Java primitive data type
*/
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
final val JAVA_INT = "int"
final val JAVA_LONG = "long"
final val JAVA_FLOAT = "float"
final val JAVA_DOUBLE = "double"
/**
* List of java primitive data types
*/
val primitiveTypes =
Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE)
/**
* Returns true if a Java type is a Java primitive type
*/
def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt)
def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt))
/**
* Returns the specialized code to access a value from `inputRow` at `ordinal`.
*/
def getValue(input: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
case StringType => s"$input.getUTF8String($ordinal)"
case BinaryType => s"$input.getBinary($ordinal)"
case CalendarIntervalType => s"$input.getInterval($ordinal)"
case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
case _: ArrayType => s"$input.getArray($ordinal)"
case _: MapType => s"$input.getMap($ordinal)"
case NullType => "null"
case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
case _ => s"($jt)$input.get($ordinal, null)"
}
}
/**
* Returns the code to update a column in Row for a given DataType.
*/
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
// The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
// it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
case StringType | _: StructType | _: ArrayType | _: MapType =>
s"$row.update($ordinal, $value.copy())"
case _ => s"$row.update($ordinal, $value)"
}
}
/**
* Update a column in MutableRow from ExprCode.
*
* @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
*/
def updateColumn(
row: String,
dataType: DataType,
ordinal: Int,
ev: ExprCode,
nullable: Boolean,
isVectorized: Boolean = false): String = {
if (nullable) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
s"""
|if (!${ev.isNull}) {
| ${setColumn(row, dataType, ordinal, ev.value)};
|} else {
| ${setColumn(row, dataType, ordinal, "null")};
|}
""".stripMargin
} else {
s"""
|if (!${ev.isNull}) {
| ${setColumn(row, dataType, ordinal, ev.value)};
|} else {
| $row.setNullAt($ordinal);
|}
""".stripMargin
}
} else {
s"""${setColumn(row, dataType, ordinal, ev.value)};"""
}
}
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`.
*/
def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
}
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`
* that could potentially be nullable.
*/
def updateColumn(
vector: String,
rowId: String,
dataType: DataType,
ev: ExprCode,
nullable: Boolean): String = {
if (nullable) {
s"""
|if (!${ev.isNull}) {
| ${setValue(vector, rowId, dataType, ev.value)}
|} else {
| $vector.putNull($rowId);
|}
""".stripMargin
} else {
s"""${setValue(vector, rowId, dataType, ev.value)};"""
}
}
/**
* Returns the specialized code to access a value from a column vector for a given `DataType`.
*/
def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = {
if (dataType.isInstanceOf[StructType]) {
// `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
// `ordinal` parameter.
s"$vector.getStruct($rowId)"
} else {
getValue(vector, dataType, rowId)
}
}
/**
* Returns the name used in accessor and setter for a Java primitive type.
*/
def primitiveTypeName(jt: String): String = jt match {
case JAVA_INT => "Int"
case _ => boxedType(jt)
}
def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt))
/**
* Returns the Java type for a DataType.
*/
def javaType(dt: DataType): String = dt match {
case BooleanType => JAVA_BOOLEAN
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case LongType | TimestampType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
case _: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
case _: ArrayType => "ArrayData"
case _: MapType => "MapData"
case udt: UserDefinedType[_] => javaType(udt.sqlType)
case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
case ObjectType(cls) => cls.getName
case _ => "Object"
}
/**
* Returns the boxed type in Java.
*/
def boxedType(jt: String): String = jt match {
case JAVA_BOOLEAN => "Boolean"
case JAVA_BYTE => "Byte"
case JAVA_SHORT => "Short"
case JAVA_INT => "Integer"
case JAVA_LONG => "Long"
case JAVA_FLOAT => "Float"
case JAVA_DOUBLE => "Double"
case other => other
}
def boxedType(dt: DataType): String = boxedType(javaType(dt))
/**
* Returns the representation of default value for a given Java Type.
* @param jt the string name of the Java type
* @param typedNull if true, for null literals, return a typed (with a cast) version
*/
def defaultValue(jt: String, typedNull: Boolean): String = jt match {
case JAVA_BOOLEAN => "false"
case JAVA_BYTE => "(byte)-1"
case JAVA_SHORT => "(short)-1"
case JAVA_INT => "-1"
case JAVA_LONG => "-1L"
case JAVA_FLOAT => "-1.0f"
case JAVA_DOUBLE => "-1.0"
case _ => if (typedNull) s"(($jt)null)" else "null"
}
def defaultValue(dt: DataType, typedNull: Boolean = false): String =
defaultValue(javaType(dt), typedNull)
}
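
To sketch a call site after the move (a hypothetical snippet; the generated variable names are illustrative): the type, default-value, and accessor helpers are reached through the `CodeGenerator` object, while a `CodegenContext` instance is still needed for stateful services such as `freshName` and `addMutableState`.

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.types.IntegerType

val ctx   = new CodegenContext                             // stateful: fresh names, mutable state
val jt    = CodeGenerator.javaType(IntegerType)            // "int"
val dv    = CodeGenerator.defaultValue(IntegerType)        // "-1"
val get   = CodeGenerator.getValue("i", IntegerType, "0")  // "i.getInt(0)"
val value = ctx.freshName("value")                         // e.g. "value_0" (illustrative)

// Generated Java for a nullable int column read, built as a string:
val code =
  s"""
     |boolean isNull = i.isNullAt(0);
     |$jt $value = isNull ? $dv : ($get);
   """.stripMargin
```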


@ -44,20 +44,21 @@ trait CodegenFallback extends Expression {
}
val objectTerm = ctx.freshName("obj")
val placeHolder = ctx.registerComment(this.toString)
val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
ev.copy(code = s"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)};
if (!${ev.isNull}) {
${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
ev.copy(code = s"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
""", isNull = "false")
}
}


@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
case (ev, i) =>
val e = expressions(i)
val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
if (e.nullable) {
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull")
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
(s"""
|${ev.code}
|$isNull = ${ev.isNull};
@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
val updates = validExpr.zip(projectionCodes).map {
case (e, (_, isNull, value, i)) =>
val ev = ExprCode("", isNull, value)
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))


@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
s"""
${ctx.INPUT_ROW} = a;
boolean $isNullA;
${ctx.javaType(order.child.dataType)} $primitiveA;
${CodeGenerator.javaType(order.child.dataType)} $primitiveA;
{
${eval.code}
$isNullA = ${eval.isNull};
@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
}
${ctx.INPUT_ROW} = b;
boolean $isNullB;
${ctx.javaType(order.child.dataType)} $primitiveB;
${CodeGenerator.javaType(order.child.dataType)} $primitiveB;
{
${eval.code}
$isNullB = ${eval.isNull};


@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val rowClass = classOf[GenericInternalRow].getName
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt)
s"""
if (!$tmpInput.isNullAt($i)) {
${converter.code}
@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val arrayClass = classOf[GenericArrayData].getName
val elementConverter = convertToSafe(
ctx, ctx.getValue(tmpInput, elementType, index), elementType)
ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType)
val code = s"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
mutableRow.setNullAt($i);
} else {
${converter.code}
${ctx.setColumn("mutableRow", e.dataType, i, converter.value)};
${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)};
}
"""
}


@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
val tmpInput = ctx.freshName("tmpInput")
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
}
s"""
@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case other => other
}
val jt = ctx.javaType(et)
val jt = CodeGenerator.javaType(et)
val elementOrOffsetSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize
case _ => 8 // we need 8 bytes to store offset and length
}
val tmpCursor = ctx.freshName("tmpCursor")
val element = ctx.getValue(tmpInput, et, index)
val element = CodeGenerator.getValue(tmpInput, et, index)
val writeElement = et match {
case t: StructType =>
s"""
@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
val primitiveTypeName =
if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else ""
s"""
final ArrayData $tmpInput = $input;
if ($tmpInput instanceof UnsafeArrayData) {


@ -20,7 +20,7 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = "false")
}
}
@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (arr, value) => {
val i = ctx.freshName("i")
val getValue = ctx.getValue(arr, right.dataType, i)
val getValue = CodeGenerator.getValue(arr, right.dataType, i)
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {


@ -90,7 +90,7 @@ private [sql] object GenArrayData {
val arrayDataName = ctx.freshName("arrayData")
val numElements = elementsCode.length
if (!ctx.isPrimitiveType(elementType)) {
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayName = ctx.freshName("arrayObject")
val genericArrayClass = classOf[GenericArrayData].getName
@ -124,7 +124,7 @@ private [sql] object GenArrayData {
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
s"$arrayDataName.setNullAt($i);"


@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
}
"""
} else {
s"""
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
"""
}
})
@ -205,7 +205,7 @@ case class GetArrayStructFields(
} else {
final InternalRow $row = $eval.getStruct($j, $numFields);
$nullSafeEval {
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
$values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)};
}
}
}
@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
}
"""
})
@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression)
} else {
""
}
val keyJavaType = CodeGenerator.javaType(keyType)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression)
int $index = 0;
boolean $found = false;
while ($index < $length && !$found) {
final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)};
final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)};
if (${ctx.genEqual(keyType, key, eval2)}) {
$found = true;
} else {
@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression)
if (!$found$nullCheck) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(values, dataType, index)};
${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
}
"""
})


@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
s"""
|${condEval.code}
|boolean ${ev.isNull} = false;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${condEval.isNull} && ${condEval.value}) {
| ${trueEval.code}
| ${ev.isNull} = ${trueEval.isNull};
@ -191,7 +191,7 @@ case class CaseWhen(
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
// We won't go on anymore on the computation.
val resultState = ctx.freshName("caseWhenResultState")
ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value)
// these blocks are meant to be inside a
// do {
@ -244,10 +244,10 @@ case class CaseWhen(
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = allConditions,
funcName = "caseWhen",
returnType = ctx.JAVA_BYTE,
returnType = CodeGenerator.JAVA_BYTE,
makeSplitFunction = func =>
s"""
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $func
|} while (false);
@ -264,7 +264,7 @@ case class CaseWhen(
ev.copy(code =
s"""
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes
|} while (false);


@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@ -673,18 +673,19 @@ abstract class UnixTime
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
left.dataType match {
case StringType if right.foldable =>
val df = classOf[DateFormat].getName
if (formatter == null) {
ExprCode("", "true", ctx.defaultValue(dataType))
ExprCode.forNullValue(dataType)
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L;
@ -713,7 +714,7 @@ abstract class UnixTime
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ${eval1.value} / 1000000L;
}""")
@ -724,7 +725,7 @@ abstract class UnixTime
ev.copy(code = s"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L;
}""")
@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
ev.copy(code = s"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
try {
${ev.value} = UTF8String.fromString($formatterName.format(
@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
: ExprCode = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
ev.copy(code = s"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $dtu.$truncFuncStr;
}""")


@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression {
}
}
val hashResultType = ctx.javaType(dataType)
val hashResultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression {
ctx: CodegenContext): String = {
val element = ctx.freshName("element")
val jt = CodeGenerator.javaType(elementType)
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
s"""
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}
"""
}
@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression {
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}
val hashResultType = ctx.javaType(dataType)
val hashResultType = CodeGenerator.javaType(dataType)
ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
extraArguments = Seq(ctx.JAVA_INT -> ev.value),
returnType = ctx.JAVA_INT,
extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childHash = 0;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$body
|return ${ev.value};
""".stripMargin,
@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
s"""
|${ctx.JAVA_INT} ${ev.value} = $seed;
|${ctx.JAVA_INT} $childHash = 0;
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
""".stripMargin)
}
@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
""".stripMargin
}
s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
expressions = fieldsHash,
funcName = "computeHashForStruct",
arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
returnType = ctx.JAVA_INT,
arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_INT} $childResult = 0;
|${CodeGenerator.JAVA_INT} $childResult = 0;
|$body
|return $result;
""".stripMargin,

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getInputFilePath();", isNull = "false")
}
}
@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getStartOffset();", isNull = "false")
}
}
@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
s"$className.getLength();", isNull = "false")
}
}

View file

@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
if (value == null) {
val defaultValueLiteral = ctx.defaultValue(javaType) match {
case "null" => s"(($javaType)null)"
case lit => lit
}
ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
ExprCode.forNullValue(dataType)
} else {
dataType match {
case BooleanType | IntegerType | DateType =>
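
For reference, the inline logic that `ExprCode.forNullValue` replaces is reproduced below as a standalone helper; the factory is assumed to encapsulate the same idea of emitting a typed null so the generated Java resolves to the right overload:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, ExprCode}
import org.apache.spark.sql.types.DataType

// Sketch of the behaviour the removed branch implemented by hand.
def nullLiteral(dataType: DataType): ExprCode = {
  val javaType = CodeGenerator.javaType(dataType)
  val valueLiteral = CodeGenerator.defaultValue(dataType) match {
    case "null" => s"(($javaType)null)"  // keep the cast for reference types
    case lit    => lit                   // primitives already have a typed default literal
  }
  ExprCode(code = "", isNull = "true", value = valueLiteral)
}
```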

View file

@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression,
}"""
}
val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
ev.copy(code = s"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluationCode
}""")

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
// all the evals are meant to be in a do { ... } while (false); loop
val evals = children.map { e =>
@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
""".stripMargin
}
val resultType = ctx.javaType(dataType)
val resultType = CodeGenerator.javaType(dataType)
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "coalesce",
returnType = resultType,
makeSplitFunction = func =>
s"""
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
| $func
|} while (false);
@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
s"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
| $codes
|} while (false);
@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
case DoubleType | FloatType =>
ev.copy(code = s"""
${eval.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false")
}
}
@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression)
ev.copy(code = s"""
${leftGen.code}
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (${leftGen.isNull}) {
${ev.isNull} = true;
} else {
@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = evals,
funcName = "atLeastNNonNulls",
extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
returnType = ctx.JAVA_INT,
extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil,
returnType = CodeGenerator.JAVA_INT,
makeSplitFunction = body =>
s"""
|do {
@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
ev.copy(code =
s"""
|${ctx.JAVA_INT} $nonnull = 0;
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|do {
| $codes
|} while (false);
|${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
""".stripMargin, isNull = "false")
}
}
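
Note that `addMutableState` itself stays on the context, since it mutates per-class state; only the type-name constant moved. The returned name must be used in the generated code because the field may be renamed or packed into a compacted array, which is why `Coalesce` reassigns `ev.isNull` above. A minimal sketch, assuming a fresh context:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}

val ctx = new CodegenContext
// Declare a boolean member; the context decides the final field (or array slot) name.
val isNullField = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
val generated = s"$isNullField = true;"   // always splice the returned name, not "isNull"
```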

View file

@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression {
def prepareArguments(ctx: CodegenContext): (String, String, String) = {
val resultIsNull = if (needNullCheck) {
val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull")
val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull")
resultIsNull
} else {
"false"
}
val argValues = arguments.map { e =>
val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue")
val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue")
argValue
}
@ -137,7 +137,7 @@ case class StaticInvoke(
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@ -151,7 +151,7 @@ case class StaticInvoke(
}
val evaluate = if (returnNullable) {
if (ctx.defaultValue(dataType) == "null") {
if (CodeGenerator.defaultValue(dataType) == "null") {
s"""
${ev.value} = $callFunc;
${ev.isNull} = ${ev.value} == null;
@ -159,7 +159,7 @@ case class StaticInvoke(
} else {
val boxedResult = ctx.freshName("boxedResult")
s"""
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc;
${ev.isNull} = $boxedResult == null;
if (!${ev.isNull}) {
${ev.value} = $boxedResult;
@ -173,7 +173,7 @@ case class StaticInvoke(
val code = s"""
$argCode
$prepareIsNull
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!$resultIsNull) {
$evaluate
}
@ -228,7 +228,7 @@ case class Invoke(
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val obj = targetObject.genCode(ctx)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@ -255,11 +255,11 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still
// set correctly.
val assignResult = if (!returnNullable) {
s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
} else {
s"""
if ($funcResult != null) {
${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
} else {
${ev.isNull} = true;
}
@ -275,7 +275,7 @@ case class Invoke(
val code = s"""
${obj.code}
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${obj.isNull}) {
$argCode
${ev.isNull} = $resultIsNull;
@ -341,7 +341,7 @@ case class NewInstance(
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
@ -358,7 +358,8 @@ case class NewInstance(
val code = s"""
$argCode
${outer.map(_.code).getOrElse("")}
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
final $javaType ${ev.value} = ${ev.isNull} ?
${CodeGenerator.defaultValue(dataType)} : $constructorCall;
"""
ev.copy(code = code)
}
@ -385,15 +386,15 @@ case class UnwrapOption(
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val inputObject = child.genCode(ctx)
val code = s"""
${inputObject.code}
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} = ${ev.isNull} ?
${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
"""
ev.copy(code = code)
}
@ -546,7 +547,7 @@ case class MapObjects private(
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
val elementJavaType = CodeGenerator.javaType(loopVarDataType)
ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false)
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
@ -554,7 +555,7 @@ case class MapObjects private(
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
val convertedType = ctx.boxedType(lambdaFunction.dataType)
val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType)
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@ -621,7 +622,7 @@ case class MapObjects private(
(
s"${genInputData.value}.numElements()",
"",
ctx.getValue(genInputData.value, et, loopIndex)
CodeGenerator.getValue(genInputData.value, et, loopIndex)
)
case ObjectType(cls) if cls == classOf[Object] =>
val it = ctx.freshName("it")
@ -643,7 +644,8 @@ case class MapObjects private(
}
val loopNullCheck = if (loopIsNull != "false") {
ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
case _ => s"$loopIsNull = $loopValue == null;"
@ -695,7 +697,7 @@ case class MapObjects private(
val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
$determineCollectionType
@ -806,10 +808,10 @@ case class CatalystToExternalMap private(
}
val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
val keyElementJavaType = ctx.javaType(mapType.keyType)
val keyElementJavaType = CodeGenerator.javaType(mapType.keyType)
ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false)
val genKeyFunction = keyLambdaFunction.genCode(ctx)
val valueElementJavaType = ctx.javaType(mapType.valueType)
val valueElementJavaType = CodeGenerator.javaType(mapType.valueType)
ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true,
useFreshName = false)
val genValueFunction = valueLambdaFunction.genCode(ctx)
@ -825,10 +827,11 @@ case class CatalystToExternalMap private(
val valueArray = ctx.freshName("valueArray")
val getKeyArray =
s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
val getValueArray =
s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
val getValueLoopVar = CodeGenerator.getValue(
valueArray, inputDataType(mapType.valueType), loopIndex)
// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
@ -844,7 +847,7 @@ case class CatalystToExternalMap private(
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
val valueLoopNullCheck = if (valueLoopIsNull != "false") {
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
useFreshName = false)
s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
} else {
@ -873,7 +876,7 @@ case class CatalystToExternalMap private(
val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
int $dataLength = $getLength;
@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private(
val entry = ctx.freshName("entry")
val entries = ctx.freshName("entries")
val keyElementJavaType = ctx.javaType(keyType)
val valueElementJavaType = ctx.javaType(valueType)
val keyElementJavaType = CodeGenerator.javaType(keyType)
val valueElementJavaType = CodeGenerator.javaType(valueType)
ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false)
ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false)
@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
$key = (${ctx.boxedType(keyType)}) $entry.getKey();
$value = (${ctx.boxedType(valueType)}) $entry.getValue();
$key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey();
$value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue();
"""
defineEntries -> defineKeyValue
@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private(
val defineKeyValue =
s"""
final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
$key = (${ctx.boxedType(keyType)}) $entry._1();
$value = (${ctx.boxedType(valueType)}) $entry._2();
$key = (${CodeGenerator.boxedType(keyType)}) $entry._1();
$value = (${CodeGenerator.boxedType(valueType)}) $entry._2();
"""
defineEntries -> defineKeyValue
}
val keyNullCheck = if (keyIsNull != "false") {
ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
s"$keyIsNull = $key == null;"
} else {
""
}
val valueNullCheck = if (valueIsNull != "false") {
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
s"$valueIsNull = $value == null;"
} else {
""
@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private(
val arrayCls = classOf[GenericArrayData].getName
val mapCls = classOf[ArrayBasedMapData].getName
val convertedKeyType = ctx.boxedType(keyConverter.dataType)
val convertedValueType = ctx.boxedType(valueConverter.dataType)
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
val code =
s"""
${inputMap.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${inputMap.isNull}) {
final int $length = ${inputMap.value}.size();
final Object[] $convertedKeys = new Object[$length];
@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
// Code to serialize.
val input = child.genCode(ctx)
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
val code = s"""
${input.code}
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
// Code to deserialize.
val input = child.genCode(ctx)
val javaType = ctx.javaType(dataType)
val javaType = CodeGenerator.javaType(dataType)
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
val code = s"""
${input.code}
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
"""
ev.copy(code = code, isNull = input.isNull)
}
@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
val instanceGen = beanInstance.genCode(ctx)
val javaBeanInstance = ctx.freshName("javaBean")
val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType)
val initialize = setters.map {
case (setterMethod, fieldValue) =>
@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType)
case _: ArrayType =>
s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
case _ =>
s"$obj instanceof ${ctx.boxedType(dataType)}"
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
val code = s"""
${input.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {
if ($typeCheck) {
${ev.value} = (${ctx.boxedType(dataType)}) $obj;
${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
} else {
throw new RuntimeException($obj.getClass().getName() + $errMsgField);
}
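
Several of the `objects.scala` call sites cast a reflective `Object` result back to the expression's type; the cast needs the boxed name while the local declaration needs the unboxed one. A sketch of the distinction, with `funcResult` as a placeholder identifier:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.types.IntegerType

val jt    = CodeGenerator.javaType(IntegerType)   // "int", used for local declarations
val boxed = CodeGenerator.boxedType(IntegerType)  // the boxed wrapper name ("Integer"), used for casts
val assign = s"value = ($boxed) funcResult;"
```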

View file

@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaDataType = ctx.javaType(value.dataType)
val javaDataType = CodeGenerator.javaType(value.dataType)
val valueGen = value.genCode(ctx)
val listGen = list.map(_.genCode(ctx))
// inTmpResult has 3 possible values:
@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = listCode,
funcName = "valueIn",
extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
returnType = ctx.JAVA_BYTE,
extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil,
returnType = CodeGenerator.JAVA_BYTE,
makeSplitFunction = body =>
s"""
|do {
@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
ev.copy(code =
s"""
|${childGen.code}
|${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${ctx.JAVA_BOOLEAN} ${ev.value} = false;
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|if (!${ev.isNull}) {
| ${ev.value} = $setTerm.contains(${childGen.value});
| $setIsNull
@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (ctx.isPrimitiveType(left.dataType)
if (CodeGenerator.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
&& left.dataType != FloatType
&& left.dataType != DoubleType) {

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG {
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
isNull = "false")
}
}
@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG {
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
ev.copy(code = s"""
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
isNull = "false")
}
}

View file

@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
ev.copy(code = s"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
}
@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
}
} else {
@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
ev.copy(code = s"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0);
}
@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
}
} else {

View file

@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
ev.copy(s"""
$initCode
$codes
${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
}
@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression])
} else {
val array = ctx.freshName("array")
val varargNum = ctx.freshName("varargNum")
val idxInVararg = ctx.freshName("idxInVararg")
val idxVararg = ctx.freshName("idxInVararg")
val evals = children.map(_.genCode(ctx))
val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression])
if (eval.isNull == "true") {
""
} else {
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
})
case _: ArrayType =>
val size = ctx.freshName("n")
@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression])
if (!${eval.isNull}) {
final int $size = ${eval.value}.numElements();
for (int j = 0; j < $size; j ++) {
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
$array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")};
}
}
""")
@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression])
val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
expressions = varargBuild,
funcName = "varargBuildsConcatWs",
extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil,
returnType = "int",
makeSplitFunction = body =>
s"""
|$body
|return $idxInVararg;
|return $idxVararg;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n"))
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
ev.copy(
s"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxInVararg = 0;
int $idxVararg = 0;
$varargCounts
UTF8String[] $array = new UTF8String[$varargNum];
$varargBuilds
@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
val indexVal = ctx.freshName("index")
val indexMatched = ctx.freshName("eltIndexMatched")
val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal")
val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
s"""
@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression {
expressions = assignInputValue,
funcName = "eltFunc",
extraArguments = ("int", indexVal) :: Nil,
returnType = ctx.JAVA_BOOLEAN,
returnType = CodeGenerator.JAVA_BOOLEAN,
makeSplitFunction = body =>
s"""
|${ctx.JAVA_BOOLEAN} $indexMatched = false;
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|do {
| $body
|} while (false);
@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression {
s"""
|${index.code}
|final int $indexVal = ${index.value};
|${ctx.JAVA_BOOLEAN} $indexMatched = false;
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|$inputVal = null;
|do {
| $codes
|} while (false);
|final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|final boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}
@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val numArgLists = argListGen.length
val argListCode = argListGen.zipWithIndex.map { case(v, index) =>
val value =
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) {
// Java primitives get boxed in order to allow null values.
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
s"new ${ctx.boxedType(v._1)}(${v._2.value})"
s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " +
s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})"
} else {
s"(${v._2.isNull}) ? null : ${v._2.value}"
}
@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
ev.copy(code = s"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
$stringBuffer $sb = new $stringBuffer();
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression)
val usLocale = "US"
val i = ctx.freshName("i")
val dFormat = ctx.freshName("dFormat")
val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;")
val lastDValue =
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
val numberFormat = ctx.addMutableState(df, "numberFormat",
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")

View file

@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-18016: define mutable states by using an array") {
val ctx1 = new CodegenContext
for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;")
}
assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
// When the number of primitive type mutable states is over the threshold, others are
// allocated into an array
assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1)
assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)
val ctx2 = new CodegenContext

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
ordinal: String,
dataType: DataType,
nullable: Boolean): ExprCode = {
val javaType = ctx.javaType(dataType)
val value = ctx.getValueFromVector(columnVar, dataType, ordinal)
val javaType = CodeGenerator.javaType(dataType)
val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
s"""
boolean $isNullVar = $columnVar.isNullAt($ordinal);
$javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
"""
} else {
s"$javaType $valueVar = $value;"
@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
// metrics
val numOutputRows = metricTerm(ctx, "numOutputRows")
val scanTimeMetric = metricTerm(ctx, "scanTime")
val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0
val scanTimeTotalNs =
ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0
val columnarBatchClz = classOf[ColumnarBatch].getName
val batch = ctx.addMutableState(columnarBatchClz, "batch")
val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0
val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
val columnVectorClzs = vectorTypes.getOrElse(
Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {

View file

@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
@ -154,7 +154,8 @@ case class ExpandExec(
val value = ctx.freshName("value")
val code = s"""
|boolean $isNull = true;
|${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)};
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
""".stripMargin
ExprCode(code, isNull, value)
}

View file

@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
@ -305,15 +305,15 @@ case class GenerateExec(
nullable: Boolean,
initialChecks: Seq[String]): ExprCode = {
val value = ctx.freshName(name)
val javaType = ctx.javaType(dt)
val getter = ctx.getValue(source, dt, index)
val javaType = CodeGenerator.javaType(dt)
val getter = CodeGenerator.getValue(source, dt, index)
val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
if (checks.nonEmpty) {
val isNull = ctx.freshName("isNull")
val code =
s"""
|boolean $isNull = ${checks.mkString(" || ")};
|$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
|$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
""".stripMargin
ExprCode(code, isNull, value)
} else {
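
`CodeGenerator.getValue` produces the type-specialised accessor text for reading an ordinal out of a row or array, which is what the generator above splices into each per-column read. A small sketch; the `"row"` name is illustrative only:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.types._

val intRead    = CodeGenerator.getValue("row", IntegerType, "0")  // "row.getInt(0)"
val stringRead = CodeGenerator.getValue("row", StringType, "1")   // "row.getUTF8String(1)"
val doubleRead = CodeGenerator.getValue("row", DoubleType, "i")   // ordinal is itself a code string
```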

View file

@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@ -133,7 +133,8 @@ case class SortExec(
override def needStopCheck: Boolean = false
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
val needToSort =
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.

View file

@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan {
variables.zipWithIndex.foreach { case (ev, i) =>
val paramName = ctx.freshName(s"expr_$i")
val paramType = ctx.javaType(attributes(i).dataType)
val paramType = CodeGenerator.javaType(attributes(i).dataType)
arguments += ev.value
parameters += s"$paramType $paramName"

View file

@ -178,7 +178,7 @@ case class HashAggregateExec(
private var bufVars: Seq[ExprCode] = _
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
// The generated function doesn't have input row in the code context.
ctx.INPUT_ROW = null
@ -186,8 +186,8 @@ case class HashAggregateExec(
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
bufVars = initExpr.map { e =>
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
val initVars = s"""
@ -532,7 +532,7 @@ case class HashAggregateExec(
*/
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
val isSupported =
(groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
@ -565,7 +565,7 @@ case class HashAggregateExec(
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
if (sqlContext.conf.enableTwoLevelAggMap) {
enableTwoLevelHashMap(ctx)
} else {
@ -757,7 +757,7 @@ case class HashAggregateExec(
val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
incCounter) = if (testFallbackStartsAt.isDefined) {
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter")
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
(s"$countTerm < ${testFallbackStartsAt.get._1}",
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
} else {
@ -832,7 +832,7 @@ case class HashAggregateExec(
}
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
}
s"""
|// common sub-expressions
@ -855,7 +855,7 @@ case class HashAggregateExec(
}
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
val dt = updateExpr(i).dataType
ctx.updateColumn(
CodeGenerator.updateColumn(
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
}

View file

@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.types._
/**
@ -41,13 +41,13 @@ abstract class HashMapGenerator(
val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
val groupingKeySignature =
groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ")
val buffVars: Seq[ExprCode] = {
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
initExpr.map { e =>
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
val ev = e.genCode(ctx)
val initVars =
s"""

View file

@ -18,8 +18,8 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.types._
/**
@ -114,7 +114,7 @@ class RowBasedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
s"""(${ctx.genEqual(key.dataType, ctx.getValue("row",
s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row",
key.dataType, ordinal.toString()), key.name)})"""
}.mkString(" && ")
}
@ -147,7 +147,7 @@ class RowBasedHashMapGenerator(
case t: DecimalType =>
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
case t: DataType =>
if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) {
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
}
s"agg_rowWriter.write(${ordinal}, ${key.name})"

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@ -127,7 +127,8 @@ class VectorizedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]")
val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType,
"buckets[idx]")
s"(${ctx.genEqual(key.dataType, value, key.name)})"
}.mkString(" && ")
}
@ -182,14 +183,14 @@ class VectorizedHashMapGenerator(
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
}
}
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
buffVars(ordinal), nullable = true)
CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows",
key.dataType, buffVars(ordinal), nullable = true)
}
}

View file

@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange")
val number = ctx.addMutableState(ctx.JAVA_LONG, "number")
val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// the metrics.
// Once number == batchEnd, it's time to progress to the next batch.
val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd")
val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
// How many values should still be generated by this range operator.
val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo")
val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
// How many values should be generated in the next batch.
val nextBatchTodo = ctx.freshName("nextBatchTodo")

View file

@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val accessorName = ctx.addMutableState(accessorCls, "accessor")
val createCode = dt match {
case t if ctx.isPrimitiveType(dt) =>
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case NullType | StringType | BinaryType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"

View file

@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
@ -182,9 +182,10 @@ case class BroadcastHashJoinExec(
    // the variables are needed even if there are no matched rows

val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val javaType = CodeGenerator.javaType(a.dataType)
val code = s"""
|boolean $isNull = true;
|${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
|$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|if ($matched != null) {
| ${ev.code}
| $isNull = ${ev.isNull};

View file

@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
@ -516,9 +516,9 @@ case class SortMergeJoinExec(
ctx.INPUT_ROW = leftRow
left.output.zipWithIndex.map { case (a, i) =>
val value = ctx.freshName("value")
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
val javaType = ctx.javaType(a.dataType)
val defaultValue = ctx.defaultValue(a.dataType)
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
val javaType = CodeGenerator.javaType(a.dataType)
val defaultValue = CodeGenerator.defaultValue(a.dataType)
if (a.nullable) {
val isNull = ctx.freshName("isNull")
val code =
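
Putting the moved helpers together, the nullable branch that `SortMergeJoinExec` builds for each left-side column looks roughly like the following; the row name, ordinal, and variable names are placeholders, and the surrounding plumbing (fresh names, `ExprCode` wiring) is elided:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.types.DataType

def nullableColumnRead(row: String, dt: DataType, ordinal: Int,
                       isNull: String, value: String): String = {
  val valueCode    = CodeGenerator.getValue(row, dt, ordinal.toString)
  val javaType     = CodeGenerator.javaType(dt)
  val defaultValue = CodeGenerator.defaultValue(dt)
  s"""
     |boolean $isNull = $row.isNullAt($ordinal);
     |$javaType $value = $isNull ? $defaultValue : ($valueCode);
   """.stripMargin
}
```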

View file

@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.util.Utils
@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
val stopEarly =
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
ctx.addNewFunction("stopEarly", s"""
@Override
@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
return $stopEarly;
}
""", inlineToOuterClass = true)
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;