[SPARK-23546][SQL] Refactor stateless methods/values in CodegenContext
## What changes were proposed in this pull request? The current `CodegenContext` class also contains immutable values and methods that do not depend on any mutable state. This refactoring moves them to the `CodeGenerator` object, so that they can be accessed from anywhere in the program without an instantiated `CodegenContext`. ## How was this patch tested? Existing tests. Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #20700 from kiszk/SPARK-23546.
This commit is contained in:
parent
269cd53590
commit
2ce37b50fc
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.internal.Logging
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.errors.attachTree
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
|
|||
ev.copy(code = oev.code)
|
||||
} else {
|
||||
assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
|
||||
if (nullable) {
|
||||
ev.copy(code =
|
||||
s"""
|
||||
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|
||||
|$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
|
||||
|$javaType ${ev.value} = ${ev.isNull} ?
|
||||
| ${CodeGenerator.defaultValue(dataType)} : ($value);
|
||||
""".stripMargin)
|
||||
} else {
|
||||
ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
|
||||
|
|
|
@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
|
||||
s"""
|
||||
boolean $resultIsNull = $inputIsNull;
|
||||
${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
|
||||
${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)};
|
||||
if (!$inputIsNull) {
|
||||
${cast(input, result, resultIsNull)}
|
||||
}
|
||||
|
@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
val funcName = ctx.freshName("elementToString")
|
||||
val elementToStringFunc = ctx.addNewFunction(funcName,
|
||||
s"""
|
||||
|private UTF8String $funcName(${ctx.javaType(et)} element) {
|
||||
|private UTF8String $funcName(${CodeGenerator.javaType(et)} element) {
|
||||
| UTF8String elementStr = null;
|
||||
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
|
||||
| return elementStr;
|
||||
|
@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
|$buffer.append("[");
|
||||
|if ($array.numElements() > 0) {
|
||||
| if (!$array.isNullAt(0)) {
|
||||
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
|
||||
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
|
||||
| }
|
||||
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
|
||||
| $buffer.append(",");
|
||||
| if (!$array.isNullAt($loopIndex)) {
|
||||
| $buffer.append(" ");
|
||||
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
|
||||
| $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
|
||||
| }
|
||||
| }
|
||||
|}
|
||||
|
@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
val dataToStringCode = castToStringCode(dataType, ctx)
|
||||
ctx.addNewFunction(funcName,
|
||||
s"""
|
||||
|private UTF8String $funcName(${ctx.javaType(dataType)} data) {
|
||||
|private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) {
|
||||
| UTF8String dataStr = null;
|
||||
| ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
|
||||
| return dataStr;
|
||||
|
@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
val keyToStringFunc = dataToStringFunc("keyToString", kt)
|
||||
val valueToStringFunc = dataToStringFunc("valueToString", vt)
|
||||
val loopIndex = ctx.freshName("loopIndex")
|
||||
val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0")
|
||||
val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0")
|
||||
val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex)
|
||||
val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex)
|
||||
s"""
|
||||
|$buffer.append("[");
|
||||
|if ($map.numElements() > 0) {
|
||||
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
|
||||
| $buffer.append($keyToStringFunc($getMapFirstKey));
|
||||
| $buffer.append(" ->");
|
||||
| if (!$map.valueArray().isNullAt(0)) {
|
||||
| $buffer.append(" ");
|
||||
| $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
|
||||
| $buffer.append($valueToStringFunc($getMapFirstValue));
|
||||
| }
|
||||
| for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
|
||||
| $buffer.append(", ");
|
||||
| $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
|
||||
| $buffer.append($keyToStringFunc($getMapKeyArray));
|
||||
| $buffer.append(" ->");
|
||||
| if (!$map.valueArray().isNullAt($loopIndex)) {
|
||||
| $buffer.append(" ");
|
||||
| $buffer.append($valueToStringFunc(
|
||||
| ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
|
||||
| $buffer.append($valueToStringFunc($getMapValueArray));
|
||||
| }
|
||||
| }
|
||||
|}
|
||||
|
@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
| ${if (i != 0) s"""$buffer.append(" ");""" else ""}
|
||||
|
|
||||
| // Append $i field into the string buffer
|
||||
| ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")};
|
||||
| ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")};
|
||||
| UTF8String $fieldStr = null;
|
||||
| ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
|
||||
| $buffer.append($fieldStr);
|
||||
|
@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
$values[$j] = null;
|
||||
} else {
|
||||
boolean $fromElementNull = false;
|
||||
${ctx.javaType(fromType)} $fromElementPrim =
|
||||
${ctx.getValue(c, fromType, j)};
|
||||
${CodeGenerator.javaType(fromType)} $fromElementPrim =
|
||||
${CodeGenerator.getValue(c, fromType, j)};
|
||||
${castCode(ctx, fromElementPrim,
|
||||
fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
|
||||
if ($toElementNull) {
|
||||
|
@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
|
|||
val fromFieldNull = ctx.freshName("ffn")
|
||||
val toFieldPrim = ctx.freshName("tfp")
|
||||
val toFieldNull = ctx.freshName("tfn")
|
||||
val fromType = ctx.javaType(from.fields(i).dataType)
|
||||
val fromType = CodeGenerator.javaType(from.fields(i).dataType)
|
||||
s"""
|
||||
boolean $fromFieldNull = $tmpInput.isNullAt($i);
|
||||
if ($fromFieldNull) {
|
||||
$tmpResult.setNullAt($i);
|
||||
} else {
|
||||
$fromType $fromFieldPrim =
|
||||
${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
|
||||
${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
|
||||
${castCode(ctx, fromFieldPrim,
|
||||
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
|
||||
if ($toFieldNull) {
|
||||
$tmpResult.setNullAt($i);
|
||||
} else {
|
||||
${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
|
||||
${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
|
@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
// TODO: support whole stage codegen too
|
||||
if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
|
||||
val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
|
||||
val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull")
|
||||
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
|
||||
val localIsNull = eval.isNull
|
||||
eval.isNull = globalIsNull
|
||||
s"$globalIsNull = $localIsNull;"
|
||||
|
@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] {
|
|||
""
|
||||
}
|
||||
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val newValue = ctx.freshName("value")
|
||||
|
||||
val funcName = ctx.freshName(nodeName)
|
||||
|
@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression {
|
|||
ev.copy(code = s"""
|
||||
${childGen.code}
|
||||
boolean ${ev.isNull} = ${childGen.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval
|
||||
""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${childGen.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$resultCode""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression {
|
|||
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval
|
||||
""")
|
||||
} else {
|
||||
|
@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression {
|
|||
boolean ${ev.isNull} = false;
|
||||
${leftGen.code}
|
||||
${rightGen.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$resultCode""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression {
|
|||
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
|
@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression {
|
|||
${leftGen.code}
|
||||
${midGen.code}
|
||||
${rightGen.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$resultCode""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types.{DataType, LongType}
|
||||
|
||||
/**
|
||||
|
@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
|
||||
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count")
|
||||
val partitionMaskTerm = "partitionMask"
|
||||
ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
|
||||
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm)
|
||||
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
|
||||
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
|
||||
|
||||
ev.copy(code = s"""
|
||||
final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
|
||||
$countTerm++;""", isNull = "false")
|
||||
}
|
||||
|
||||
|
|
|
@ -1018,11 +1018,12 @@ case class ScalaUDF(
|
|||
val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}")
|
||||
val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})"
|
||||
val resultConverter = s"$convertersTerm[${children.length}]"
|
||||
val boxedType = CodeGenerator.boxedType(dataType)
|
||||
val callFunc =
|
||||
s"""
|
||||
|${ctx.boxedType(dataType)} $resultTerm = null;
|
||||
|$boxedType $resultTerm = null;
|
||||
|try {
|
||||
| $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult);
|
||||
| $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult);
|
||||
|} catch (Exception e) {
|
||||
| throw new org.apache.spark.SparkException($errorMsgTerm, e);
|
||||
|}
|
||||
|
@ -1035,7 +1036,7 @@ case class ScalaUDF(
|
|||
|$callFunc
|
||||
|
|
||||
|boolean ${ev.isNull} = $resultTerm == null;
|
||||
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|if (!${ev.isNull}) {
|
||||
| ${ev.value} = $resultTerm;
|
||||
|}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types.{DataType, IntegerType}
|
||||
|
||||
/**
|
||||
|
@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val idTerm = "partitionId"
|
||||
ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
|
||||
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
|
||||
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
|
||||
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
|
||||
isNull = "false")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils
|
|||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
|
@ -165,7 +165,7 @@ case class PreciseTimestampConversion(
|
|||
val eval = child.genCode(ctx)
|
||||
ev.copy(code = eval.code +
|
||||
s"""boolean ${ev.isNull} = ${eval.isNull};
|
||||
|${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
|
||||
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
|
||||
""".stripMargin)
|
||||
}
|
||||
override def nullSafeEval(input: Any): Any = input
|
||||
|
|
|
@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
|
|||
// codegen would fail to compile if we just write (-($c))
|
||||
// for example, we could not write --9223372036854775808L in code
|
||||
s"""
|
||||
${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
|
||||
${ev.value} = (${ctx.javaType(dt)})(-($originValue));
|
||||
${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
|
||||
${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
|
||||
"""})
|
||||
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ case class Abs(child: Expression)
|
|||
case dt: DecimalType =>
|
||||
defineCodeGen(ctx, ev, c => s"$c.abs()")
|
||||
case dt: NumericType =>
|
||||
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))")
|
||||
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
|
||||
}
|
||||
|
||||
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
|
||||
|
@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
|
|||
// byte and short are cast to int when adding, subtracting, multiplying or dividing
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case _ =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
|
||||
}
|
||||
|
@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
|
||||
case _ =>
|
||||
|
@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
|
|||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
|
||||
case ByteType | ShortType =>
|
||||
defineCodeGen(ctx, ev,
|
||||
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
|
||||
case CalendarIntervalType =>
|
||||
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
|
||||
case _ =>
|
||||
|
@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
|
|||
} else {
|
||||
s"${eval2.value} == 0"
|
||||
}
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val divide = if (dataType.isInstanceOf[DecimalType]) {
|
||||
s"${eval1.value}.$decimalMethod(${eval2.value})"
|
||||
} else {
|
||||
|
@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if ($isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (${eval2.isNull} || $isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
|
|||
} else {
|
||||
s"${eval2.value} == 0"
|
||||
}
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val remainder = if (dataType.isInstanceOf[DecimalType]) {
|
||||
s"${eval1.value}.$decimalMethod(${eval2.value})"
|
||||
} else {
|
||||
|
@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if ($isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (${eval2.isNull} || $isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
s"${eval2.value} == 0"
|
||||
}
|
||||
val remainder = ctx.freshName("remainder")
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
|
||||
val result = dataType match {
|
||||
case DecimalType.Fixed(_, _) =>
|
||||
val decimalAdd = "$plus"
|
||||
s"""
|
||||
${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value});
|
||||
$javaType $remainder = ${eval1.value}.remainder(${eval2.value});
|
||||
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
|
||||
${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value});
|
||||
} else {
|
||||
|
@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
// byte and short are cast to int when adding, subtracting, multiplying or dividing
|
||||
case ByteType | ShortType =>
|
||||
s"""
|
||||
${ctx.javaType(dataType)} $remainder =
|
||||
(${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value});
|
||||
$javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
|
||||
if ($remainder < 0) {
|
||||
${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value});
|
||||
${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
|
||||
} else {
|
||||
${ev.value}=$remainder;
|
||||
}
|
||||
"""
|
||||
case _ =>
|
||||
s"""
|
||||
${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value};
|
||||
$javaType $remainder = ${eval1.value} % ${eval2.value};
|
||||
if ($remainder < 0) {
|
||||
${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
|
||||
} else {
|
||||
|
@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if ($isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
|||
ev.copy(code = s"""
|
||||
${eval2.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(javaType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (${eval2.isNull} || $isZero) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val evalChildren = children.map(_.genCode(ctx))
|
||||
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
|
||||
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
|
||||
val evals = evalChildren.map(eval =>
|
||||
s"""
|
||||
|${eval.code}
|
||||
|
@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression {
|
|||
""".stripMargin
|
||||
)
|
||||
|
||||
val resultType = ctx.javaType(dataType)
|
||||
val resultType = CodeGenerator.javaType(dataType)
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = evals,
|
||||
funcName = "least",
|
||||
|
@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression {
|
|||
ev.copy(code =
|
||||
s"""
|
||||
|${ev.isNull} = true;
|
||||
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|$codes
|
||||
""".stripMargin)
|
||||
}
|
||||
|
@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val evalChildren = children.map(_.genCode(ctx))
|
||||
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
|
||||
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
|
||||
val evals = evalChildren.map(eval =>
|
||||
s"""
|
||||
|${eval.code}
|
||||
|
@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
""".stripMargin
|
||||
)
|
||||
|
||||
val resultType = ctx.javaType(dataType)
|
||||
val resultType = CodeGenerator.javaType(dataType)
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = evals,
|
||||
funcName = "greatest",
|
||||
|
@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
ev.copy(code =
|
||||
s"""
|
||||
|${ev.isNull} = true;
|
||||
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|$codes
|
||||
""".stripMargin)
|
||||
}
|
||||
|
|
|
@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
|
||||
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)")
|
||||
}
|
||||
|
||||
protected override def nullSafeEval(input: Any): Any = not(input)
|
||||
|
|
|
@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
|
|||
case class ExprCode(var code: String, var isNull: String, var value: String)
|
||||
|
||||
object ExprCode {
|
||||
def forNullValue(dataType: DataType): ExprCode = {
|
||||
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
|
||||
ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
|
||||
}
|
||||
|
||||
def forNonNullValue(value: String): ExprCode = {
|
||||
ExprCode(code = "", isNull = "false", value = value)
|
||||
}
|
||||
|
@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec(
|
|||
*/
|
||||
class CodegenContext {
|
||||
|
||||
import CodeGenerator._
|
||||
|
||||
/**
|
||||
* Holds a list of objects that can be passed into the generated class.
|
||||
*/
|
||||
|
@ -196,11 +203,11 @@ class CodegenContext {
|
|||
|
||||
/**
|
||||
* Returns the reference of next available slot in current compacted array. The size of each
|
||||
* compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
|
||||
* compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`.
|
||||
* Once the threshold is reached, a new compacted array is created.
|
||||
*/
|
||||
def getNextSlot(): String = {
|
||||
if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) {
|
||||
if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) {
|
||||
val res = s"${arrayNames.last}[$currentIndex]"
|
||||
currentIndex += 1
|
||||
res
|
||||
|
@ -247,10 +254,10 @@ class CodegenContext {
|
|||
* are satisfied:
|
||||
* 1. forceInline is true
|
||||
* 2. its type is primitive type and the total number of the inlined mutable variables
|
||||
* is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD`
|
||||
* is less than `OUTER_CLASS_VARIABLES_THRESHOLD`
|
||||
* 3. its type is multi-dimensional array
|
||||
* When a variable is compacted into an array, the max size of the array for compaction
|
||||
* is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`.
|
||||
* is given by `MUTABLESTATEARRAY_SIZE_LIMIT`.
|
||||
*/
|
||||
def addMutableState(
|
||||
javaType: String,
|
||||
|
@ -261,7 +268,7 @@ class CodegenContext {
|
|||
|
||||
// want to put a primitive type variable at outerClass for performance
|
||||
val canInlinePrimitive = isPrimitiveType(javaType) &&
|
||||
(inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
|
||||
(inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD)
|
||||
if (forceInline || canInlinePrimitive || javaType.contains("[][]")) {
|
||||
val varName = if (useFreshName) freshName(variableName) else variableName
|
||||
val initCode = initFunc(varName)
|
||||
|
@ -339,7 +346,7 @@ class CodegenContext {
|
|||
val length = if (index + 1 == numArrays) {
|
||||
mutableStateArrays.getCurrentIndex
|
||||
} else {
|
||||
CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT
|
||||
MUTABLESTATEARRAY_SIZE_LIMIT
|
||||
}
|
||||
if (javaType.contains("[]")) {
|
||||
// initializer had a one-dimensional array variable
|
||||
|
@ -468,7 +475,7 @@ class CodegenContext {
|
|||
inlineToOuterClass: Boolean): NewFunctionSpec = {
|
||||
val (className, classInstance) = if (inlineToOuterClass) {
|
||||
outerClassName -> ""
|
||||
} else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) {
|
||||
} else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) {
|
||||
val className = freshName("NestedClass")
|
||||
val classInstance = freshName("nestedClassInstance")
|
||||
|
||||
|
@ -537,14 +544,6 @@ class CodegenContext {
|
|||
extraClasses.append(code)
|
||||
}
|
||||
|
||||
final val JAVA_BOOLEAN = "boolean"
|
||||
final val JAVA_BYTE = "byte"
|
||||
final val JAVA_SHORT = "short"
|
||||
final val JAVA_INT = "int"
|
||||
final val JAVA_LONG = "long"
|
||||
final val JAVA_FLOAT = "float"
|
||||
final val JAVA_DOUBLE = "double"
|
||||
|
||||
/**
|
||||
* The map from a variable name to its next ID.
|
||||
*/
|
||||
|
@ -580,196 +579,6 @@ class CodegenContext {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the specialized code to access a value from `input` at `ordinal`.
|
||||
*/
|
||||
def getValue(input: String, dataType: DataType, ordinal: String): String = {
|
||||
val jt = javaType(dataType)
|
||||
dataType match {
|
||||
case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
|
||||
case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"
|
||||
case StringType => s"$input.getUTF8String($ordinal)"
|
||||
case BinaryType => s"$input.getBinary($ordinal)"
|
||||
case CalendarIntervalType => s"$input.getInterval($ordinal)"
|
||||
case t: StructType => s"$input.getStruct($ordinal, ${t.size})"
|
||||
case _: ArrayType => s"$input.getArray($ordinal)"
|
||||
case _: MapType => s"$input.getMap($ordinal)"
|
||||
case NullType => "null"
|
||||
case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
|
||||
case _ => s"($jt)$input.get($ordinal, null)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the Java expression (without a trailing semicolon) that stores `value` into column
 * `ordinal` of `row`, using the setter appropriate for `dataType`.
 *
 * @param row      Java expression of the row being written
 * @param dataType Catalyst type of the column
 * @param ordinal  index of the column
 * @param value    Java expression of the value to store
 */
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
  val javaTypeName = javaType(dataType)
  if (isPrimitiveType(javaTypeName)) {
    s"$row.set${primitiveTypeName(javaTypeName)}($ordinal, $value)"
  } else {
    dataType match {
      case decimal: DecimalType => s"$row.setDecimal($ordinal, $value, ${decimal.precision})"
      // A user-defined type is stored through its underlying SQL type.
      case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
      // The UTF8String, InternalRow, ArrayData and MapData may come from UnsafeRow, so we copy
      // them to avoid keeping a "pointer" to a memory region which may get updated afterwards.
      case StringType | _: StructType | _: ArrayType | _: MapType =>
        s"$row.update($ordinal, $value.copy())"
      case _ => s"$row.update($ordinal, $value)"
    }
  }
}
|
||||
|
||||
/**
 * Builds the Java statement(s) that write the result of an evaluated expression into a column
 * of a MutableRow.
 *
 * @param row          Java expression of the row being written
 * @param dataType     Catalyst type of the column
 * @param ordinal      index of the column
 * @param ev           evaluated expression providing the value and its null flag
 * @param nullable     whether a null check has to be generated around the write
 * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
 */
def updateColumn(
    row: String,
    dataType: DataType,
    ordinal: Int,
    ev: ExprCode,
    nullable: Boolean,
    isVectorized: Boolean = false): String = {
  val assignValue = setColumn(row, dataType, ordinal, ev.value) + ";"
  if (!nullable) {
    assignValue
  } else {
    // Can't call setNullAt on DecimalType, because we need to keep the offset
    val assignNull =
      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
        setColumn(row, dataType, ordinal, "null") + ";"
      } else {
        s"$row.setNullAt($ordinal);"
      }
    s"""
      if (!${ev.isNull}) {
        $assignValue
      } else {
        $assignNull
      }
    """
  }
}
|
||||
|
||||
/**
 * Builds the Java statement, including its terminating semicolon, that puts `value` into
 * position `rowId` of a column vector holding values of `dataType`.
 *
 * @param vector   Java expression of the column vector
 * @param rowId    Java expression of the position to write
 * @param dataType Catalyst type of the value
 * @param value    Java expression of the value to write
 * @throws IllegalArgumentException if `dataType` is neither primitive-backed, decimal nor string
 */
def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
  val javaTypeName = javaType(dataType)
  if (isPrimitiveType(javaTypeName)) {
    s"$vector.put${primitiveTypeName(javaTypeName)}($rowId, $value);"
  } else {
    dataType match {
      case decimal: DecimalType =>
        s"$vector.putDecimal($rowId, $value, ${decimal.precision});"
      case _: StringType =>
        s"$vector.putByteArray($rowId, $value.getBytes());"
      case _ =>
        throw new IllegalArgumentException(
          s"cannot generate code for unsupported type: $dataType")
    }
  }
}
|
||||
|
||||
/**
 * Builds the Java statement(s) that write the result of an evaluated expression into a column
 * vector, for a value that could potentially be null.
 *
 * @param vector   Java expression of the column vector
 * @param rowId    Java expression of the position to write
 * @param dataType Catalyst type of the value
 * @param ev       evaluated expression providing the value and its null flag
 * @param nullable whether a null check has to be generated around the write
 * @throws IllegalArgumentException if `setValue` does not support `dataType`
 */
def updateColumn(
    vector: String,
    rowId: String,
    dataType: DataType,
    ev: ExprCode,
    nullable: Boolean): String = {
  // `setValue` already yields a complete, semicolon-terminated statement, so nothing is appended
  // to it in either branch.
  val putValue = setValue(vector, rowId, dataType, ev.value)
  if (nullable) {
    s"""
      if (!${ev.isNull}) {
        $putValue
      } else {
        $vector.putNull($rowId);
      }
    """
  } else {
    putValue
  }
}
|
||||
|
||||
/**
 * Builds the Java expression that reads the value at `rowId` from a column vector, using the
 * accessor specialized for `dataType`.
 *
 * @param vector   Java expression of the column vector
 * @param dataType Catalyst type of the value to read
 * @param rowId    Java expression of the position to read
 */
def getValueFromVector(vector: String, dataType: DataType, rowId: String): String =
  dataType match {
    // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
    // `ordinal` parameter.
    case _: StructType => s"$vector.getStruct($rowId)"
    case _ => getValue(vector, dataType, rowId)
  }
|
||||
|
||||
/**
 * Returns the type-name part of the accessor and setter names for a Java primitive type:
 * "Int" for `int`, and the boxed type name for every other input.
 */
def primitiveTypeName(jt: String): String =
  if (jt == JAVA_INT) "Int" else boxedType(jt)

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def primitiveTypeName(dt: DataType): String = {
  val javaTypeName = javaType(dt)
  primitiveTypeName(javaTypeName)
}
|
||||
|
||||
/**
 * Maps a Catalyst `DataType` to the name of the Java type that holds its values.
 *
 * @param dt the Catalyst type
 * @return the Java type name; "Object" when no more specific type applies
 */
def javaType(dt: DataType): String = dt match {
  // Types held in Java primitives.
  case BooleanType => JAVA_BOOLEAN
  case ByteType => JAVA_BYTE
  case ShortType => JAVA_SHORT
  case DateType | IntegerType => JAVA_INT
  case TimestampType | LongType => JAVA_LONG
  case FloatType => JAVA_FLOAT
  case DoubleType => JAVA_DOUBLE

  // Types held in reference types.
  case _: DecimalType => "Decimal"
  case BinaryType => "byte[]"
  case StringType => "UTF8String"
  case CalendarIntervalType => "CalendarInterval"
  case _: StructType => "InternalRow"
  case _: ArrayType => "ArrayData"
  case _: MapType => "MapData"

  // A user-defined type uses the Java type of its underlying SQL type.
  case udt: UserDefinedType[_] => javaType(udt.sqlType)

  // Arbitrary objects: arrays are spelled recursively from their component class.
  case ObjectType(cls) =>
    if (cls.isArray) s"${javaType(ObjectType(cls.getComponentType))}[]" else cls.getName

  case _ => "Object"
}
|
||||
|
||||
/**
 * Returns the name of the boxed Java type for a primitive type name; any other name is returned
 * unchanged.
 */
def boxedType(jt: String): String = jt match {
  // `int` is the only primitive handled here whose box is not its capitalized name.
  case JAVA_INT => "Integer"
  case JAVA_BOOLEAN | JAVA_BYTE | JAVA_SHORT | JAVA_LONG | JAVA_FLOAT | JAVA_DOUBLE =>
    jt.capitalize
  case _ => jt
}

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def boxedType(dt: DataType): String = {
  val javaTypeName = javaType(dt)
  boxedType(javaTypeName)
}
|
||||
|
||||
/**
 * Returns the Java literal used as the default value of a Java type: `false` for boolean, a
 * suitably typed `-1` for the numeric primitives, and `null` for everything else.
 *
 * @param jt the string name of the Java type
 */
def defaultValue(jt: String): String = jt match {
  case JAVA_BOOLEAN => "false"
  // Integral primitives.
  case JAVA_BYTE => "(byte)-1"
  case JAVA_SHORT => "(short)-1"
  case JAVA_INT => "-1"
  case JAVA_LONG => "-1L"
  // Floating-point primitives.
  case JAVA_FLOAT => "-1.0f"
  case JAVA_DOUBLE => "-1.0"
  case _ => "null"
}

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def defaultValue(dt: DataType): String = {
  val javaTypeName = javaType(dt)
  defaultValue(javaTypeName)
}
|
||||
|
||||
/**
|
||||
* Generates code for equal expression in Java.
|
||||
*/
|
||||
|
@ -812,6 +621,7 @@ class CodegenContext {
|
|||
val isNullB = freshName("isNullB")
|
||||
val compareFunc = freshName("compareArray")
|
||||
val minLength = freshName("minLength")
|
||||
val jt = javaType(elementType)
|
||||
val funcCode: String =
|
||||
s"""
|
||||
public int $compareFunc(ArrayData a, ArrayData b) {
|
||||
|
@ -833,8 +643,8 @@ class CodegenContext {
|
|||
} else if ($isNullB) {
|
||||
return 1;
|
||||
} else {
|
||||
${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")};
|
||||
${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")};
|
||||
$jt $elementA = ${getValue("a", elementType, "i")};
|
||||
$jt $elementB = ${getValue("b", elementType, "i")};
|
||||
int comp = ${genComp(elementType, elementA, elementB)};
|
||||
if (comp != 0) {
|
||||
return comp;
|
||||
|
@ -906,19 +716,6 @@ class CodegenContext {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Java type names that have special accessors and setters in [[InternalRow]].
 */
val primitiveTypes: Seq[String] = Seq(
  JAVA_BOOLEAN,
  JAVA_BYTE,
  JAVA_SHORT,
  JAVA_INT,
  JAVA_LONG,
  JAVA_FLOAT,
  JAVA_DOUBLE)
|
||||
|
||||
/**
 * Tells whether the Java type named `jt` has a special accessor and setter in [[InternalRow]],
 * i.e. whether it is listed in `primitiveTypes`.
 */
def isPrimitiveType(jt: String): Boolean = primitiveTypes.exists(_ == jt)

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def isPrimitiveType(dt: DataType): Boolean = {
  val javaTypeName = javaType(dt)
  isPrimitiveType(javaTypeName)
}
|
||||
|
||||
/**
|
||||
* Splits the generated code of expressions into multiple functions, because function has
|
||||
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow
|
||||
|
@ -1089,7 +886,7 @@ class CodegenContext {
|
|||
// for performance reasons, the functions are prepended, instead of appended,
|
||||
// thus here they are in reversed order
|
||||
val orderedFunctions = innerClassFunctions.reverse
|
||||
if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) {
|
||||
if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) {
|
||||
// Adding a new function to each inner class which contains the invocation of all the
|
||||
// ones which have been added to that inner class. For example,
|
||||
// private class NestedClass {
|
||||
|
@ -1289,7 +1086,7 @@ class CodegenContext {
|
|||
* length less than a pre-defined constant.
|
||||
*/
|
||||
def isValidParamLength(paramLength: Int): Boolean = {
|
||||
paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
|
||||
paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging {
|
|||
result
|
||||
}
|
||||
})
|
||||
|
||||
/**
 * Names of the Java primitive types, spelled as they appear in Java source.
 *
 * They are declared as `final val`s with literal right-hand sides and are used as match patterns
 * by `boxedType`, `defaultValue` and `primitiveTypeName`.
 */
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
final val JAVA_INT = "int"
final val JAVA_LONG = "long"
final val JAVA_FLOAT = "float"
final val JAVA_DOUBLE = "double"
|
||||
|
||||
/**
 * The names of all Java primitive types handled by the code generator.
 */
val primitiveTypes: Seq[String] = Seq(
  JAVA_BOOLEAN,
  JAVA_BYTE,
  JAVA_SHORT,
  JAVA_INT,
  JAVA_LONG,
  JAVA_FLOAT,
  JAVA_DOUBLE)
|
||||
|
||||
/**
 * Tells whether `jt` names a Java primitive type, i.e. whether it is listed in `primitiveTypes`.
 */
def isPrimitiveType(jt: String): Boolean = primitiveTypes.exists(_ == jt)

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def isPrimitiveType(dt: DataType): Boolean = {
  val javaTypeName = javaType(dt)
  isPrimitiveType(javaTypeName)
}
|
||||
|
||||
/**
 * Builds the Java expression that reads the value at `ordinal` from `input`, choosing the
 * accessor specialized for `dataType`.
 *
 * @param input    Java expression of the object being read (e.g. an input row)
 * @param dataType Catalyst type of the value to read
 * @param ordinal  Java expression of the position to read
 * @return Java source of the read expression; the literal `null` for `NullType`
 */
def getValue(input: String, dataType: DataType, ordinal: String): String = {
  val javaTypeName = javaType(dataType)
  if (isPrimitiveType(javaTypeName)) {
    // Primitives have dedicated accessors such as `getInt` or `getBoolean`.
    s"$input.get${primitiveTypeName(javaTypeName)}($ordinal)"
  } else {
    dataType match {
      case decimal: DecimalType =>
        s"$input.getDecimal($ordinal, ${decimal.precision}, ${decimal.scale})"
      case StringType => s"$input.getUTF8String($ordinal)"
      case BinaryType => s"$input.getBinary($ordinal)"
      case CalendarIntervalType => s"$input.getInterval($ordinal)"
      case struct: StructType => s"$input.getStruct($ordinal, ${struct.size})"
      case _: ArrayType => s"$input.getArray($ordinal)"
      case _: MapType => s"$input.getMap($ordinal)"
      case NullType => "null"
      // A user-defined type is read through its underlying SQL type.
      case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
      // Anything else goes through the generic getter and is cast to its Java type.
      case _ => s"($javaTypeName)$input.get($ordinal, null)"
    }
  }
}
|
||||
|
||||
/**
 * Builds the Java expression (without a trailing semicolon) that stores `value` into column
 * `ordinal` of `row`, using the setter appropriate for `dataType`.
 *
 * @param row      Java expression of the row being written
 * @param dataType Catalyst type of the column
 * @param ordinal  index of the column
 * @param value    Java expression of the value to store
 */
def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
  val javaTypeName = javaType(dataType)
  if (isPrimitiveType(javaTypeName)) {
    s"$row.set${primitiveTypeName(javaTypeName)}($ordinal, $value)"
  } else {
    dataType match {
      case decimal: DecimalType => s"$row.setDecimal($ordinal, $value, ${decimal.precision})"
      // A user-defined type is stored through its underlying SQL type.
      case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
      // The UTF8String, InternalRow, ArrayData and MapData may come from UnsafeRow, so we copy
      // them to avoid keeping a "pointer" to a memory region which may get updated afterwards.
      case StringType | _: StructType | _: ArrayType | _: MapType =>
        s"$row.update($ordinal, $value.copy())"
      case _ => s"$row.update($ordinal, $value)"
    }
  }
}
|
||||
|
||||
/**
 * Builds the Java statement(s) that write the result of an evaluated expression into a column
 * of a MutableRow.
 *
 * @param row          Java expression of the row being written
 * @param dataType     Catalyst type of the column
 * @param ordinal      index of the column
 * @param ev           evaluated expression providing the value and its null flag
 * @param nullable     whether a null check has to be generated around the write
 * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
 */
def updateColumn(
    row: String,
    dataType: DataType,
    ordinal: Int,
    ev: ExprCode,
    nullable: Boolean,
    isVectorized: Boolean = false): String = {
  val assignValue = setColumn(row, dataType, ordinal, ev.value) + ";"
  if (!nullable) {
    assignValue
  } else {
    // Can't call setNullAt on DecimalType, because we need to keep the offset
    val assignNull =
      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
        setColumn(row, dataType, ordinal, "null") + ";"
      } else {
        s"$row.setNullAt($ordinal);"
      }
    s"""
       |if (!${ev.isNull}) {
       | $assignValue
       |} else {
       | $assignNull
       |}
     """.stripMargin
  }
}
|
||||
|
||||
/**
 * Builds the Java statement, including its terminating semicolon, that puts `value` into
 * position `rowId` of a column vector holding values of `dataType`.
 *
 * @param vector   Java expression of the column vector
 * @param rowId    Java expression of the position to write
 * @param dataType Catalyst type of the value
 * @param value    Java expression of the value to write
 * @throws IllegalArgumentException if `dataType` is neither primitive-backed, decimal nor string
 */
def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
  val javaTypeName = javaType(dataType)
  if (isPrimitiveType(javaTypeName)) {
    s"$vector.put${primitiveTypeName(javaTypeName)}($rowId, $value);"
  } else {
    dataType match {
      case decimal: DecimalType =>
        s"$vector.putDecimal($rowId, $value, ${decimal.precision});"
      case _: StringType =>
        s"$vector.putByteArray($rowId, $value.getBytes());"
      case _ =>
        throw new IllegalArgumentException(
          s"cannot generate code for unsupported type: $dataType")
    }
  }
}
|
||||
|
||||
/**
 * Builds the Java statement(s) that write the result of an evaluated expression into a column
 * vector, for a value that could potentially be null.
 *
 * @param vector   Java expression of the column vector
 * @param rowId    Java expression of the position to write
 * @param dataType Catalyst type of the value
 * @param ev       evaluated expression providing the value and its null flag
 * @param nullable whether a null check has to be generated around the write
 * @throws IllegalArgumentException if `setValue` does not support `dataType`
 */
def updateColumn(
    vector: String,
    rowId: String,
    dataType: DataType,
    ev: ExprCode,
    nullable: Boolean): String = {
  // `setValue` already yields a complete, semicolon-terminated statement, so nothing is appended
  // to it in either branch.
  val putValue = setValue(vector, rowId, dataType, ev.value)
  if (nullable) {
    s"""
       |if (!${ev.isNull}) {
       | $putValue
       |} else {
       | $vector.putNull($rowId);
       |}
     """.stripMargin
  } else {
    putValue
  }
}
|
||||
|
||||
/**
 * Builds the Java expression that reads the value at `rowId` from a column vector, using the
 * accessor specialized for `dataType`.
 *
 * @param vector   Java expression of the column vector
 * @param dataType Catalyst type of the value to read
 * @param rowId    Java expression of the position to read
 */
def getValueFromVector(vector: String, dataType: DataType, rowId: String): String =
  dataType match {
    // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an
    // `ordinal` parameter.
    case _: StructType => s"$vector.getStruct($rowId)"
    case _ => getValue(vector, dataType, rowId)
  }
|
||||
|
||||
/**
 * Returns the type-name part of the accessor and setter names for a Java primitive type:
 * "Int" for `int`, and the boxed type name for every other input.
 */
def primitiveTypeName(jt: String): String =
  if (jt == JAVA_INT) "Int" else boxedType(jt)

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def primitiveTypeName(dt: DataType): String = {
  val javaTypeName = javaType(dt)
  primitiveTypeName(javaTypeName)
}
|
||||
|
||||
/**
 * Maps a Catalyst `DataType` to the name of the Java type that holds its values.
 *
 * @param dt the Catalyst type
 * @return the Java type name; "Object" when no more specific type applies
 */
def javaType(dt: DataType): String = dt match {
  // Types held in Java primitives.
  case BooleanType => JAVA_BOOLEAN
  case ByteType => JAVA_BYTE
  case ShortType => JAVA_SHORT
  case DateType | IntegerType => JAVA_INT
  case TimestampType | LongType => JAVA_LONG
  case FloatType => JAVA_FLOAT
  case DoubleType => JAVA_DOUBLE

  // Types held in reference types.
  case _: DecimalType => "Decimal"
  case BinaryType => "byte[]"
  case StringType => "UTF8String"
  case CalendarIntervalType => "CalendarInterval"
  case _: StructType => "InternalRow"
  case _: ArrayType => "ArrayData"
  case _: MapType => "MapData"

  // A user-defined type uses the Java type of its underlying SQL type.
  case udt: UserDefinedType[_] => javaType(udt.sqlType)

  // Arbitrary objects: arrays are spelled recursively from their component class.
  case ObjectType(cls) =>
    if (cls.isArray) s"${javaType(ObjectType(cls.getComponentType))}[]" else cls.getName

  case _ => "Object"
}
|
||||
|
||||
/**
 * Returns the name of the boxed Java type for a primitive type name; any other name is returned
 * unchanged.
 */
def boxedType(jt: String): String = jt match {
  // `int` is the only primitive handled here whose box is not its capitalized name.
  case JAVA_INT => "Integer"
  case JAVA_BOOLEAN | JAVA_BYTE | JAVA_SHORT | JAVA_LONG | JAVA_FLOAT | JAVA_DOUBLE =>
    jt.capitalize
  case _ => jt
}

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`.
 */
def boxedType(dt: DataType): String = {
  val javaTypeName = javaType(dt)
  boxedType(javaTypeName)
}
|
||||
|
||||
/**
 * Returns the Java literal used as the default value of a Java type: `false` for boolean, a
 * suitably typed `-1` for the numeric primitives, and a null literal for everything else.
 *
 * @param jt the string name of the Java type
 * @param typedNull if true, for null literals, return a typed (with a cast) version
 */
def defaultValue(jt: String, typedNull: Boolean): String = {
  // Literal used for every type that is not one of the primitives matched below.
  def nullLiteral: String = if (typedNull) s"(($jt)null)" else "null"

  jt match {
    case JAVA_BOOLEAN => "false"
    // Integral primitives.
    case JAVA_BYTE => "(byte)-1"
    case JAVA_SHORT => "(short)-1"
    case JAVA_INT => "-1"
    case JAVA_LONG => "-1L"
    // Floating-point primitives.
    case JAVA_FLOAT => "-1.0f"
    case JAVA_DOUBLE => "-1.0"
    case _ => nullLiteral
  }
}

/**
 * Same as the `String` overload, applied to the Java type of the Catalyst type `dt`; null
 * literals are untyped unless `typedNull` is set.
 */
def defaultValue(dt: DataType, typedNull: Boolean = false): String = {
  val javaTypeName = javaType(dt)
  defaultValue(javaTypeName, typedNull)
}
|
||||
}
|
||||
|
|
|
@ -44,20 +44,21 @@ trait CodegenFallback extends Expression {
|
|||
}
|
||||
val objectTerm = ctx.freshName("obj")
|
||||
val placeHolder = ctx.registerComment(this.toString)
|
||||
val javaType = CodeGenerator.javaType(this.dataType)
|
||||
if (nullable) {
|
||||
ev.copy(code = s"""
|
||||
$placeHolder
|
||||
Object $objectTerm = ((Expression) references[$idx]).eval($input);
|
||||
boolean ${ev.isNull} = $objectTerm == null;
|
||||
${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
|
||||
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
|
||||
}""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
$placeHolder
|
||||
Object $objectTerm = ((Expression) references[$idx]).eval($input);
|
||||
${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm;
|
||||
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
|
||||
""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
|
|||
val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map {
|
||||
case (ev, i) =>
|
||||
val e = expressions(i)
|
||||
val value = ctx.addMutableState(ctx.javaType(e.dataType), "value")
|
||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
|
||||
if (e.nullable) {
|
||||
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull")
|
||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
|
||||
(s"""
|
||||
|${ev.code}
|
||||
|$isNull = ${ev.isNull};
|
||||
|
@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
|
|||
val updates = validExpr.zip(projectionCodes).map {
|
||||
case (e, (_, isNull, value, i)) =>
|
||||
val ev = ExprCode("", isNull, value)
|
||||
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
|
||||
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
|
||||
}
|
||||
|
||||
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
|
||||
|
|
|
@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
|
|||
s"""
|
||||
${ctx.INPUT_ROW} = a;
|
||||
boolean $isNullA;
|
||||
${ctx.javaType(order.child.dataType)} $primitiveA;
|
||||
${CodeGenerator.javaType(order.child.dataType)} $primitiveA;
|
||||
{
|
||||
${eval.code}
|
||||
$isNullA = ${eval.isNull};
|
||||
|
@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
|
|||
}
|
||||
${ctx.INPUT_ROW} = b;
|
||||
boolean $isNullB;
|
||||
${ctx.javaType(order.child.dataType)} $primitiveB;
|
||||
${CodeGenerator.javaType(order.child.dataType)} $primitiveB;
|
||||
{
|
||||
${eval.code}
|
||||
$isNullB = ${eval.isNull};
|
||||
|
|
|
@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
val rowClass = classOf[GenericInternalRow].getName
|
||||
|
||||
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
|
||||
val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt)
|
||||
val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt)
|
||||
s"""
|
||||
if (!$tmpInput.isNullAt($i)) {
|
||||
${converter.code}
|
||||
|
@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
val arrayClass = classOf[GenericArrayData].getName
|
||||
|
||||
val elementConverter = convertToSafe(
|
||||
ctx, ctx.getValue(tmpInput, elementType, index), elementType)
|
||||
ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType)
|
||||
val code = s"""
|
||||
final ArrayData $tmpInput = $input;
|
||||
final int $numElements = $tmpInput.numElements();
|
||||
|
@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
|||
mutableRow.setNullAt($i);
|
||||
} else {
|
||||
${converter.code}
|
||||
${ctx.setColumn("mutableRow", e.dataType, i, converter.value)};
|
||||
${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)};
|
||||
}
|
||||
"""
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
|
||||
val tmpInput = ctx.freshName("tmpInput")
|
||||
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
|
||||
ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString))
|
||||
ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString))
|
||||
}
|
||||
|
||||
s"""
|
||||
|
@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
case other => other
|
||||
}
|
||||
|
||||
val jt = ctx.javaType(et)
|
||||
val jt = CodeGenerator.javaType(et)
|
||||
|
||||
val elementOrOffsetSize = et match {
|
||||
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
|
||||
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
|
||||
case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize
|
||||
case _ => 8 // we need 8 bytes to store offset and length
|
||||
}
|
||||
|
||||
val tmpCursor = ctx.freshName("tmpCursor")
|
||||
val element = ctx.getValue(tmpInput, et, index)
|
||||
val element = CodeGenerator.getValue(tmpInput, et, index)
|
||||
val writeElement = et match {
|
||||
case t: StructType =>
|
||||
s"""
|
||||
|
@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
case _ => s"$arrayWriter.write($index, $element);"
|
||||
}
|
||||
|
||||
val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
|
||||
val primitiveTypeName =
|
||||
if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else ""
|
||||
s"""
|
||||
final ArrayData $tmpInput = $input;
|
||||
if ($tmpInput instanceof UnsafeArrayData) {
|
||||
|
|
|
@ -20,7 +20,7 @@ import java.util.Comparator
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
|
|||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = false;
|
||||
${childGen.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
|
||||
(${childGen.value}).numElements();""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, (arr, value) => {
|
||||
val i = ctx.freshName("i")
|
||||
val getValue = ctx.getValue(arr, right.dataType, i)
|
||||
val getValue = CodeGenerator.getValue(arr, right.dataType, i)
|
||||
s"""
|
||||
for (int $i = 0; $i < $arr.numElements(); $i ++) {
|
||||
if ($arr.isNullAt($i)) {
|
||||
|
|
|
@ -90,7 +90,7 @@ private [sql] object GenArrayData {
|
|||
val arrayDataName = ctx.freshName("arrayData")
|
||||
val numElements = elementsCode.length
|
||||
|
||||
if (!ctx.isPrimitiveType(elementType)) {
|
||||
if (!CodeGenerator.isPrimitiveType(elementType)) {
|
||||
val arrayName = ctx.freshName("arrayObject")
|
||||
val genericArrayClass = classOf[GenericArrayData].getName
|
||||
|
||||
|
@ -124,7 +124,7 @@ private [sql] object GenArrayData {
|
|||
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
|
||||
val baseOffset = Platform.BYTE_ARRAY_OFFSET
|
||||
|
||||
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
|
||||
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
||||
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
|
||||
val isNullAssignment = if (!isMapKey) {
|
||||
s"$arrayDataName.setNullAt($i);"
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
|
|||
if ($eval.isNullAt($ordinal)) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
|
||||
${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
|
||||
}
|
||||
"""
|
||||
} else {
|
||||
s"""
|
||||
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
|
||||
${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)};
|
||||
"""
|
||||
}
|
||||
})
|
||||
|
@ -205,7 +205,7 @@ case class GetArrayStructFields(
|
|||
} else {
|
||||
final InternalRow $row = $eval.getStruct($j, $numFields);
|
||||
$nullSafeEval {
|
||||
$values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)};
|
||||
$values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
|
|||
if ($index >= $eval1.numElements() || $index < 0$nullCheck) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.value} = ${ctx.getValue(eval1, dataType, index)};
|
||||
${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
|
||||
}
|
||||
"""
|
||||
})
|
||||
|
@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression)
|
|||
} else {
|
||||
""
|
||||
}
|
||||
val keyJavaType = CodeGenerator.javaType(keyType)
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
|
||||
s"""
|
||||
final int $length = $eval1.numElements();
|
||||
|
@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression)
|
|||
int $index = 0;
|
||||
boolean $found = false;
|
||||
while ($index < $length && !$found) {
|
||||
final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)};
|
||||
final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)};
|
||||
if (${ctx.genEqual(keyType, key, eval2)}) {
|
||||
$found = true;
|
||||
} else {
|
||||
|
@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression)
|
|||
if (!$found$nullCheck) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
${ev.value} = ${ctx.getValue(values, dataType, index)};
|
||||
${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
|
||||
}
|
||||
"""
|
||||
})
|
||||
|
|
|
@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
|
|||
s"""
|
||||
|${condEval.code}
|
||||
|boolean ${ev.isNull} = false;
|
||||
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|if (!${condEval.isNull} && ${condEval.value}) {
|
||||
| ${trueEval.code}
|
||||
| ${ev.isNull} = ${trueEval.isNull};
|
||||
|
@ -191,7 +191,7 @@ case class CaseWhen(
|
|||
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
|
||||
// We won't go on anymore on the computation.
|
||||
val resultState = ctx.freshName("caseWhenResultState")
|
||||
ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
|
||||
ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value)
|
||||
|
||||
// these blocks are meant to be inside a
|
||||
// do {
|
||||
|
@ -244,10 +244,10 @@ case class CaseWhen(
|
|||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = allConditions,
|
||||
funcName = "caseWhen",
|
||||
returnType = ctx.JAVA_BYTE,
|
||||
returnType = CodeGenerator.JAVA_BYTE,
|
||||
makeSplitFunction = func =>
|
||||
s"""
|
||||
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|
||||
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|
||||
|do {
|
||||
| $func
|
||||
|} while (false);
|
||||
|
@ -264,7 +264,7 @@ case class CaseWhen(
|
|||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
|
||||
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|
||||
|do {
|
||||
| $codes
|
||||
|} while (false);
|
||||
|
|
|
@ -26,7 +26,7 @@ import scala.util.control.NonFatal
|
|||
import org.apache.commons.lang3.StringEscapeUtils
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
@ -673,18 +673,19 @@ abstract class UnixTime
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
left.dataType match {
|
||||
case StringType if right.foldable =>
|
||||
val df = classOf[DateFormat].getName
|
||||
if (formatter == null) {
|
||||
ExprCode("", "true", ctx.defaultValue(dataType))
|
||||
ExprCode.forNullValue(dataType)
|
||||
} else {
|
||||
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
|
||||
val eval1 = left.genCode(ctx)
|
||||
ev.copy(code = s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
try {
|
||||
${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L;
|
||||
|
@ -713,7 +714,7 @@ abstract class UnixTime
|
|||
ev.copy(code = s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = ${eval1.value} / 1000000L;
|
||||
}""")
|
||||
|
@ -724,7 +725,7 @@ abstract class UnixTime
|
|||
ev.copy(code = s"""
|
||||
${eval1.code}
|
||||
boolean ${ev.isNull} = ${eval1.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L;
|
||||
}""")
|
||||
|
@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
|
|||
ev.copy(code = s"""
|
||||
${t.code}
|
||||
boolean ${ev.isNull} = ${t.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
try {
|
||||
${ev.value} = UTF8String.fromString($formatterName.format(
|
||||
|
@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
|
|||
: ExprCode = {
|
||||
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
|
||||
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
if (format.foldable) {
|
||||
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
|
||||
} else {
|
||||
val t = instant.genCode(ctx)
|
||||
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
|
||||
ev.copy(code = s"""
|
||||
${t.code}
|
||||
boolean ${ev.isNull} = ${t.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $dtu.$truncFuncStr;
|
||||
}""")
|
||||
|
|
|
@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
val hashResultType = ctx.javaType(dataType)
|
||||
val hashResultType = CodeGenerator.javaType(dataType)
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = childrenHash,
|
||||
funcName = "computeHash",
|
||||
|
@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression {
|
|||
ctx: CodegenContext): String = {
|
||||
val element = ctx.freshName("element")
|
||||
|
||||
val jt = CodeGenerator.javaType(elementType)
|
||||
ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
|
||||
s"""
|
||||
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
|
||||
final $jt $element = ${CodeGenerator.getValue(input, elementType, index)};
|
||||
${computeHash(element, elementType, result, ctx)}
|
||||
"""
|
||||
}
|
||||
|
@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression {
|
|||
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
||||
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
|
||||
}
|
||||
val hashResultType = ctx.javaType(dataType)
|
||||
val hashResultType = CodeGenerator.javaType(dataType)
|
||||
ctx.splitExpressions(
|
||||
expressions = fieldsHash,
|
||||
funcName = "computeHashForStruct",
|
||||
|
@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = childrenHash,
|
||||
funcName = "computeHash",
|
||||
extraArguments = Seq(ctx.JAVA_INT -> ev.value),
|
||||
returnType = ctx.JAVA_INT,
|
||||
extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value),
|
||||
returnType = CodeGenerator.JAVA_INT,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|${ctx.JAVA_INT} $childHash = 0;
|
||||
|${CodeGenerator.JAVA_INT} $childHash = 0;
|
||||
|$body
|
||||
|return ${ev.value};
|
||||
""".stripMargin,
|
||||
|
@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
|${ctx.JAVA_INT} ${ev.value} = $seed;
|
||||
|${ctx.JAVA_INT} $childHash = 0;
|
||||
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|
||||
|${CodeGenerator.JAVA_INT} $childHash = 0;
|
||||
|$codes
|
||||
""".stripMargin)
|
||||
}
|
||||
|
@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
|
||||
s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
|
||||
expressions = fieldsHash,
|
||||
funcName = "computeHashForStruct",
|
||||
arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
|
||||
returnType = ctx.JAVA_INT,
|
||||
arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
|
||||
returnType = CodeGenerator.JAVA_INT,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|${ctx.JAVA_INT} $childResult = 0;
|
||||
|${CodeGenerator.JAVA_INT} $childResult = 0;
|
||||
|$body
|
||||
|return $result;
|
||||
""".stripMargin,
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.rdd.InputFileBlockHolder
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types.{DataType, LongType, StringType}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
|
|||
|
||||
/**
 * Emits Java source that declares a `final` local named `ev.value`, typed with the Java
 * type for this expression's `dataType`, and initialises it from
 * `InputFileBlockHolder.getInputFilePath()`. The returned [[ExprCode]] has its null flag
 * fixed to the literal "false".
 *
 * NOTE(review): this listing carries two `ev.copy(code = ...` openers, one using
 * `ctx.javaType` and one using `CodeGenerator.javaType`; it looks like the before/after
 * of the same statement -- confirm which one is live in the real file.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
// Runtime class name of InputFileBlockHolder with any trailing "$" removed; it is
// interpolated below as the receiver of the generated call.
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getInputFilePath();", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
|
|||
|
||||
/**
 * Emits Java source that declares a `final` local named `ev.value`, typed with the Java
 * type for this expression's `dataType`, and initialises it from
 * `InputFileBlockHolder.getStartOffset()`. The returned [[ExprCode]] has its null flag
 * fixed to the literal "false".
 *
 * NOTE(review): this listing carries two `ev.copy(code = ...` openers, one using
 * `ctx.javaType` and one using `CodeGenerator.javaType`; it looks like the before/after
 * of the same statement -- confirm which one is live in the real file.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
// Runtime class name of InputFileBlockHolder with any trailing "$" removed; it is
// interpolated below as the receiver of the generated call.
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getStartOffset();", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
|
|||
|
||||
/**
 * Emits Java source that declares a `final` local named `ev.value`, typed with the Java
 * type for this expression's `dataType`, and initialises it from
 * `InputFileBlockHolder.getLength()`. The returned [[ExprCode]] has its null flag fixed
 * to the literal "false".
 *
 * NOTE(review): this listing carries two `ev.copy(code = ...` openers, one using
 * `ctx.javaType` and one using `CodeGenerator.javaType`; it looks like the before/after
 * of the same statement -- confirm which one is live in the real file.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
// Runtime class name of InputFileBlockHolder with any trailing "$" removed; it is
// interpolated below as the receiver of the generated call.
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
|
||||
ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " +
|
||||
ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
|
||||
s"$className.getLength();", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
|
|||
override def eval(input: InternalRow): Any = value
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
if (value == null) {
|
||||
val defaultValueLiteral = ctx.defaultValue(javaType) match {
|
||||
case "null" => s"(($javaType)null)"
|
||||
case lit => lit
|
||||
}
|
||||
ExprCode(code = "", isNull = "true", value = defaultValueLiteral)
|
||||
ExprCode.forNullValue(dataType)
|
||||
} else {
|
||||
dataType match {
|
||||
case BooleanType | IntegerType | DateType =>
|
||||
|
|
|
@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression,
|
|||
}"""
|
||||
}
|
||||
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
if (scaleV == null) { // if scale is null, no need to eval its child at all
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
|
||||
} else {
|
||||
ev.copy(code = s"""
|
||||
${ce.code}
|
||||
boolean ${ev.isNull} = ${ce.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
$evaluationCode
|
||||
}""")
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
|
||||
ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull)
|
||||
|
||||
// all the evals are meant to be in a do { ... } while (false); loop
|
||||
val evals = children.map { e =>
|
||||
|
@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
val resultType = ctx.javaType(dataType)
|
||||
val resultType = CodeGenerator.javaType(dataType)
|
||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = evals,
|
||||
funcName = "coalesce",
|
||||
returnType = resultType,
|
||||
makeSplitFunction = func =>
|
||||
s"""
|
||||
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|do {
|
||||
| $func
|
||||
|} while (false);
|
||||
|
@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
ev.copy(code =
|
||||
s"""
|
||||
|${ev.isNull} = true;
|
||||
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|do {
|
||||
| $codes
|
||||
|} while (false);
|
||||
|
@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
|
|||
case DoubleType | FloatType =>
|
||||
ev.copy(code = s"""
|
||||
${eval.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false")
|
||||
}
|
||||
}
|
||||
|
@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression)
|
|||
ev.copy(code = s"""
|
||||
${leftGen.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (${leftGen.isNull}) {
|
||||
${ev.isNull} = true;
|
||||
} else {
|
||||
|
@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
|
|||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = evals,
|
||||
funcName = "atLeastNNonNulls",
|
||||
extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
|
||||
returnType = ctx.JAVA_INT,
|
||||
extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil,
|
||||
returnType = CodeGenerator.JAVA_INT,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|do {
|
||||
|
@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
|
|||
|
||||
ev.copy(code =
|
||||
s"""
|
||||
|${ctx.JAVA_INT} $nonnull = 0;
|
||||
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|
||||
|do {
|
||||
| $codes
|
||||
|} while (false);
|
||||
|${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
|
||||
""".stripMargin, isNull = "false")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
|||
import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
|
||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression {
|
|||
def prepareArguments(ctx: CodegenContext): (String, String, String) = {
|
||||
|
||||
val resultIsNull = if (needNullCheck) {
|
||||
val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull")
|
||||
val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull")
|
||||
resultIsNull
|
||||
} else {
|
||||
"false"
|
||||
}
|
||||
val argValues = arguments.map { e =>
|
||||
val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue")
|
||||
val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue")
|
||||
argValue
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ case class StaticInvoke(
|
|||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
|
||||
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
|
||||
|
||||
|
@ -151,7 +151,7 @@ case class StaticInvoke(
|
|||
}
|
||||
|
||||
val evaluate = if (returnNullable) {
|
||||
if (ctx.defaultValue(dataType) == "null") {
|
||||
if (CodeGenerator.defaultValue(dataType) == "null") {
|
||||
s"""
|
||||
${ev.value} = $callFunc;
|
||||
${ev.isNull} = ${ev.value} == null;
|
||||
|
@ -159,7 +159,7 @@ case class StaticInvoke(
|
|||
} else {
|
||||
val boxedResult = ctx.freshName("boxedResult")
|
||||
s"""
|
||||
${ctx.boxedType(dataType)} $boxedResult = $callFunc;
|
||||
${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc;
|
||||
${ev.isNull} = $boxedResult == null;
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $boxedResult;
|
||||
|
@ -173,7 +173,7 @@ case class StaticInvoke(
|
|||
val code = s"""
|
||||
$argCode
|
||||
$prepareIsNull
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!$resultIsNull) {
|
||||
$evaluate
|
||||
}
|
||||
|
@ -228,7 +228,7 @@ case class Invoke(
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val obj = targetObject.genCode(ctx)
|
||||
|
||||
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
|
||||
|
@ -255,11 +255,11 @@ case class Invoke(
|
|||
// If the function can return null, we do an extra check to make sure our null bit is still
|
||||
// set correctly.
|
||||
val assignResult = if (!returnNullable) {
|
||||
s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;"
|
||||
s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;"
|
||||
} else {
|
||||
s"""
|
||||
if ($funcResult != null) {
|
||||
${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
|
||||
${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;
|
||||
} else {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
|
@ -275,7 +275,7 @@ case class Invoke(
|
|||
val code = s"""
|
||||
${obj.code}
|
||||
boolean ${ev.isNull} = true;
|
||||
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${obj.isNull}) {
|
||||
$argCode
|
||||
${ev.isNull} = $resultIsNull;
|
||||
|
@ -341,7 +341,7 @@ case class NewInstance(
|
|||
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
|
||||
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
|
||||
|
||||
|
@ -358,7 +358,8 @@ case class NewInstance(
|
|||
val code = s"""
|
||||
$argCode
|
||||
${outer.map(_.code).getOrElse("")}
|
||||
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
|
||||
final $javaType ${ev.value} = ${ev.isNull} ?
|
||||
${CodeGenerator.defaultValue(dataType)} : $constructorCall;
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
|
@ -385,15 +386,15 @@ case class UnwrapOption(
|
|||
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
|
||||
|
||||
/**
 * Emits Java source that unwraps the value produced by `child`.
 *
 * The generated code first runs the child's code, then sets `ev.isNull` to true when the
 * child is null or its value reports `isEmpty()`. `ev.value` receives the type's default
 * value in that case; otherwise it receives the child's `get()` result cast to the boxed
 * Java type.
 *
 * NOTE(review): this listing contains both `ctx.*` and `CodeGenerator.*` forms of the
 * `javaType` definition and of the `ev.value` assignment; they look like the before/after
 * of the same statements -- confirm which ones are live in the real file. Note that the
 * two forms pass different arguments to `defaultValue` (`javaType` vs `dataType`).
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
// Generate the child's code first; its code/isNull/value terms are spliced in below.
val inputObject = child.genCode(ctx)
|
||||
|
||||
val code = s"""
|
||||
${inputObject.code}
|
||||
|
||||
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
|
||||
$javaType ${ev.value} = ${ev.isNull} ?
|
||||
${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get();
|
||||
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
|
||||
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
|
||||
"""
|
||||
ev.copy(code = code)
|
||||
}
|
||||
|
@ -546,7 +547,7 @@ case class MapObjects private(
|
|||
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val elementJavaType = ctx.javaType(loopVarDataType)
|
||||
val elementJavaType = CodeGenerator.javaType(loopVarDataType)
|
||||
ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false)
|
||||
val genInputData = inputData.genCode(ctx)
|
||||
val genFunction = lambdaFunction.genCode(ctx)
|
||||
|
@ -554,7 +555,7 @@ case class MapObjects private(
|
|||
val convertedArray = ctx.freshName("convertedArray")
|
||||
val loopIndex = ctx.freshName("loopIndex")
|
||||
|
||||
val convertedType = ctx.boxedType(lambdaFunction.dataType)
|
||||
val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType)
|
||||
|
||||
// Because of the way Java defines nested arrays, we have to handle the syntax specially.
|
||||
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested
|
||||
|
@ -621,7 +622,7 @@ case class MapObjects private(
|
|||
(
|
||||
s"${genInputData.value}.numElements()",
|
||||
"",
|
||||
ctx.getValue(genInputData.value, et, loopIndex)
|
||||
CodeGenerator.getValue(genInputData.value, et, loopIndex)
|
||||
)
|
||||
case ObjectType(cls) if cls == classOf[Object] =>
|
||||
val it = ctx.freshName("it")
|
||||
|
@ -643,7 +644,8 @@ case class MapObjects private(
|
|||
}
|
||||
|
||||
val loopNullCheck = if (loopIsNull != "false") {
|
||||
ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
|
||||
ctx.addMutableState(
|
||||
CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false)
|
||||
inputDataType match {
|
||||
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
|
||||
case _ => s"$loopIsNull = $loopValue == null;"
|
||||
|
@ -695,7 +697,7 @@ case class MapObjects private(
|
|||
|
||||
val code = s"""
|
||||
${genInputData.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
||||
if (!${genInputData.isNull}) {
|
||||
$determineCollectionType
|
||||
|
@ -806,10 +808,10 @@ case class CatalystToExternalMap private(
|
|||
}
|
||||
|
||||
val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
|
||||
val keyElementJavaType = ctx.javaType(mapType.keyType)
|
||||
val keyElementJavaType = CodeGenerator.javaType(mapType.keyType)
|
||||
ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false)
|
||||
val genKeyFunction = keyLambdaFunction.genCode(ctx)
|
||||
val valueElementJavaType = ctx.javaType(mapType.valueType)
|
||||
val valueElementJavaType = CodeGenerator.javaType(mapType.valueType)
|
||||
ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true,
|
||||
useFreshName = false)
|
||||
val genValueFunction = valueLambdaFunction.genCode(ctx)
|
||||
|
@ -825,10 +827,11 @@ case class CatalystToExternalMap private(
|
|||
val valueArray = ctx.freshName("valueArray")
|
||||
val getKeyArray =
|
||||
s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
|
||||
val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
|
||||
val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
|
||||
val getValueArray =
|
||||
s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
|
||||
val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
|
||||
val getValueLoopVar = CodeGenerator.getValue(
|
||||
valueArray, inputDataType(mapType.valueType), loopIndex)
|
||||
|
||||
// Make a copy of the data if it's unsafe-backed
|
||||
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
|
||||
|
@ -844,7 +847,7 @@ case class CatalystToExternalMap private(
|
|||
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
|
||||
|
||||
val valueLoopNullCheck = if (valueLoopIsNull != "false") {
|
||||
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
|
||||
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true,
|
||||
useFreshName = false)
|
||||
s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
|
||||
} else {
|
||||
|
@ -873,7 +876,7 @@ case class CatalystToExternalMap private(
|
|||
|
||||
val code = s"""
|
||||
${genInputData.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
|
||||
if (!${genInputData.isNull}) {
|
||||
int $dataLength = $getLength;
|
||||
|
@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private(
|
|||
val entry = ctx.freshName("entry")
|
||||
val entries = ctx.freshName("entries")
|
||||
|
||||
val keyElementJavaType = ctx.javaType(keyType)
|
||||
val valueElementJavaType = ctx.javaType(valueType)
|
||||
val keyElementJavaType = CodeGenerator.javaType(keyType)
|
||||
val valueElementJavaType = CodeGenerator.javaType(valueType)
|
||||
ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false)
|
||||
ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false)
|
||||
|
||||
|
@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private(
|
|||
val defineKeyValue =
|
||||
s"""
|
||||
final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
|
||||
$key = (${ctx.boxedType(keyType)}) $entry.getKey();
|
||||
$value = (${ctx.boxedType(valueType)}) $entry.getValue();
|
||||
$key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey();
|
||||
$value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue();
|
||||
"""
|
||||
|
||||
defineEntries -> defineKeyValue
|
||||
|
@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private(
|
|||
val defineKeyValue =
|
||||
s"""
|
||||
final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
|
||||
$key = (${ctx.boxedType(keyType)}) $entry._1();
|
||||
$value = (${ctx.boxedType(valueType)}) $entry._2();
|
||||
$key = (${CodeGenerator.boxedType(keyType)}) $entry._1();
|
||||
$value = (${CodeGenerator.boxedType(valueType)}) $entry._2();
|
||||
"""
|
||||
|
||||
defineEntries -> defineKeyValue
|
||||
}
|
||||
|
||||
val keyNullCheck = if (keyIsNull != "false") {
|
||||
ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
|
||||
ctx.addMutableState(
|
||||
CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false)
|
||||
s"$keyIsNull = $key == null;"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
|
||||
val valueNullCheck = if (valueIsNull != "false") {
|
||||
ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
|
||||
ctx.addMutableState(
|
||||
CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false)
|
||||
s"$valueIsNull = $value == null;"
|
||||
} else {
|
||||
""
|
||||
|
@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private(
|
|||
|
||||
val arrayCls = classOf[GenericArrayData].getName
|
||||
val mapCls = classOf[ArrayBasedMapData].getName
|
||||
val convertedKeyType = ctx.boxedType(keyConverter.dataType)
|
||||
val convertedValueType = ctx.boxedType(valueConverter.dataType)
|
||||
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
|
||||
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
|
||||
val code =
|
||||
s"""
|
||||
${inputMap.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${inputMap.isNull}) {
|
||||
final int $length = ${inputMap.value}.size();
|
||||
final Object[] $convertedKeys = new Object[$length];
|
||||
|
@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
|
|||
|
||||
// Code to serialize.
|
||||
val input = child.genCode(ctx)
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val serialize = s"$serializer.serialize(${input.value}, null).array()"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize;
|
||||
final $javaType ${ev.value} =
|
||||
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
|
||||
"""
|
||||
ev.copy(code = code, isNull = input.isNull)
|
||||
}
|
||||
|
@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
|
|||
|
||||
// Code to deserialize.
|
||||
val input = child.genCode(ctx)
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val deserialize =
|
||||
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize;
|
||||
final $javaType ${ev.value} =
|
||||
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
|
||||
"""
|
||||
ev.copy(code = code, isNull = input.isNull)
|
||||
}
|
||||
|
@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
|
|||
val instanceGen = beanInstance.genCode(ctx)
|
||||
|
||||
val javaBeanInstance = ctx.freshName("javaBean")
|
||||
val beanInstanceJavaType = ctx.javaType(beanInstance.dataType)
|
||||
val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType)
|
||||
|
||||
val initialize = setters.map {
|
||||
case (setterMethod, fieldValue) =>
|
||||
|
@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType)
|
|||
case _: ArrayType =>
|
||||
s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
|
||||
case _ =>
|
||||
s"$obj instanceof ${ctx.boxedType(dataType)}"
|
||||
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
|
||||
}
|
||||
|
||||
val code = s"""
|
||||
${input.code}
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${input.isNull}) {
|
||||
if ($typeCheck) {
|
||||
${ev.value} = (${ctx.boxedType(dataType)}) $obj;
|
||||
${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj;
|
||||
} else {
|
||||
throw new RuntimeException($obj.getClass().getName() + $errMsgField);
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val javaDataType = ctx.javaType(value.dataType)
|
||||
val javaDataType = CodeGenerator.javaType(value.dataType)
|
||||
val valueGen = value.genCode(ctx)
|
||||
val listGen = list.map(_.genCode(ctx))
|
||||
// inTmpResult has 3 possible values:
|
||||
|
@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
|||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = listCode,
|
||||
funcName = "valueIn",
|
||||
extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil,
|
||||
returnType = ctx.JAVA_BYTE,
|
||||
extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil,
|
||||
returnType = CodeGenerator.JAVA_BYTE,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|do {
|
||||
|
@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
|||
ev.copy(code =
|
||||
s"""
|
||||
|${childGen.code}
|
||||
|${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|
||||
|${ctx.JAVA_BOOLEAN} ${ev.value} = false;
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|
||||
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|
||||
|if (!${ev.isNull}) {
|
||||
| ${ev.value} = $setTerm.contains(${childGen.value});
|
||||
| $setIsNull
|
||||
|
@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
|
|||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
if (ctx.isPrimitiveType(left.dataType)
|
||||
if (CodeGenerator.isPrimitiveType(left.dataType)
|
||||
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
|
||||
&& left.dataType != FloatType
|
||||
&& left.dataType != DoubleType) {
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.random.XORShiftRandom
|
||||
|
@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG {
|
|||
ctx.addPartitionInitializationStatement(
|
||||
s"$rngTerm = new $className(${seed}L + partitionIndex);")
|
||||
ev.copy(code = s"""
|
||||
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
|
||||
isNull = "false")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG {
|
|||
ctx.addPartitionInitializationStatement(
|
||||
s"$rngTerm = new $className(${seed}L + partitionIndex);")
|
||||
ev.copy(code = s"""
|
||||
final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
|
||||
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
|
||||
isNull = "false")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
|
|||
ev.copy(code = s"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
|
|||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
""")
|
||||
}
|
||||
} else {
|
||||
|
@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
|
|||
ev.copy(code = s"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0);
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
|
|||
} else {
|
||||
ev.copy(code = s"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
""")
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression {
|
|||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = inputs,
|
||||
funcName = "valueConcat",
|
||||
extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
|
||||
extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
|
||||
ev.copy(s"""
|
||||
$initCode
|
||||
$codes
|
||||
${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
|
||||
boolean ${ev.isNull} = ${ev.value} == null;
|
||||
""")
|
||||
}
|
||||
|
@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
} else {
|
||||
val array = ctx.freshName("array")
|
||||
val varargNum = ctx.freshName("varargNum")
|
||||
val idxInVararg = ctx.freshName("idxInVararg")
|
||||
val idxVararg = ctx.freshName("idxInVararg")
|
||||
|
||||
val evals = children.map(_.genCode(ctx))
|
||||
val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) =>
|
||||
|
@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
if (eval.isNull == "true") {
|
||||
""
|
||||
} else {
|
||||
s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
|
||||
s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};"
|
||||
})
|
||||
case _: ArrayType =>
|
||||
val size = ctx.freshName("n")
|
||||
|
@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression])
|
|||
if (!${eval.isNull}) {
|
||||
final int $size = ${eval.value}.numElements();
|
||||
for (int j = 0; j < $size; j ++) {
|
||||
$array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")};
|
||||
$array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")};
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression])
|
|||
val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
|
||||
expressions = varargBuild,
|
||||
funcName = "varargBuildsConcatWs",
|
||||
extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil,
|
||||
extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil,
|
||||
returnType = "int",
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|$body
|
||||
|return $idxInVararg;
|
||||
|return $idxVararg;
|
||||
""".stripMargin,
|
||||
foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n"))
|
||||
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
|
||||
|
||||
ev.copy(
|
||||
s"""
|
||||
$codes
|
||||
int $varargNum = ${children.count(_.dataType == StringType) - 1};
|
||||
int $idxInVararg = 0;
|
||||
int $idxVararg = 0;
|
||||
$varargCounts
|
||||
UTF8String[] $array = new UTF8String[$varargNum];
|
||||
$varargBuilds
|
||||
|
@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
|
|||
val indexVal = ctx.freshName("index")
|
||||
val indexMatched = ctx.freshName("eltIndexMatched")
|
||||
|
||||
val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal")
|
||||
val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal")
|
||||
|
||||
val assignInputValue = inputs.zipWithIndex.map { case (eval, index) =>
|
||||
s"""
|
||||
|
@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression {
|
|||
expressions = assignInputValue,
|
||||
funcName = "eltFunc",
|
||||
extraArguments = ("int", indexVal) :: Nil,
|
||||
returnType = ctx.JAVA_BOOLEAN,
|
||||
returnType = CodeGenerator.JAVA_BOOLEAN,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|${ctx.JAVA_BOOLEAN} $indexMatched = false;
|
||||
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|
||||
|do {
|
||||
| $body
|
||||
|} while (false);
|
||||
|
@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression {
|
|||
s"""
|
||||
|${index.code}
|
||||
|final int $indexVal = ${index.value};
|
||||
|${ctx.JAVA_BOOLEAN} $indexMatched = false;
|
||||
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
|
||||
|$inputVal = null;
|
||||
|do {
|
||||
| $codes
|
||||
|} while (false);
|
||||
|final ${ctx.javaType(dataType)} ${ev.value} = $inputVal;
|
||||
|final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal;
|
||||
|final boolean ${ev.isNull} = ${ev.value} == null;
|
||||
""".stripMargin)
|
||||
}
|
||||
|
@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
|
|||
val numArgLists = argListGen.length
|
||||
val argListCode = argListGen.zipWithIndex.map { case(v, index) =>
|
||||
val value =
|
||||
if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
|
||||
if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) {
|
||||
// Java primitives get boxed in order to allow null values.
|
||||
s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
|
||||
s"new ${ctx.boxedType(v._1)}(${v._2.value})"
|
||||
s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " +
|
||||
s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})"
|
||||
} else {
|
||||
s"(${v._2.isNull}) ? null : ${v._2.value}"
|
||||
}
|
||||
|
@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
|
|||
ev.copy(code = s"""
|
||||
${pattern.code}
|
||||
boolean ${ev.isNull} = ${pattern.isNull};
|
||||
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
$stringBuffer $sb = new $stringBuffer();
|
||||
$formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
|
||||
|
@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression)
|
|||
val usLocale = "US"
|
||||
val i = ctx.freshName("i")
|
||||
val dFormat = ctx.freshName("dFormat")
|
||||
val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;")
|
||||
val lastDValue =
|
||||
ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;")
|
||||
val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();")
|
||||
val numberFormat = ctx.addMutableState(df, "numberFormat",
|
||||
v => s"""$v = new $df("", new $dfs($l.$usLocale));""")
|
||||
|
|
|
@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("SPARK-18016: define mutable states by using an array") {
|
||||
val ctx1 = new CodegenContext
|
||||
for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) {
|
||||
ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;")
|
||||
ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;")
|
||||
}
|
||||
assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD)
|
||||
// When the number of primitive type mutable states is over the threshold, others are
|
||||
// allocated into an array
|
||||
assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1)
|
||||
assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1)
|
||||
assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10)
|
||||
|
||||
val ctx2 = new CodegenContext
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
|
||||
|
@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|
|||
ordinal: String,
|
||||
dataType: DataType,
|
||||
nullable: Boolean): ExprCode = {
|
||||
val javaType = ctx.javaType(dataType)
|
||||
val value = ctx.getValueFromVector(columnVar, dataType, ordinal)
|
||||
val javaType = CodeGenerator.javaType(dataType)
|
||||
val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
|
||||
val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
|
||||
val valueVar = ctx.freshName("value")
|
||||
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
|
||||
val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
|
||||
s"""
|
||||
boolean $isNullVar = $columnVar.isNullAt($ordinal);
|
||||
$javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value);
|
||||
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
|
||||
"""
|
||||
} else {
|
||||
s"$javaType $valueVar = $value;"
|
||||
|
@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
|
|||
// metrics
|
||||
val numOutputRows = metricTerm(ctx, "numOutputRows")
|
||||
val scanTimeMetric = metricTerm(ctx, "scanTime")
|
||||
val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0
|
||||
val scanTimeTotalNs =
|
||||
ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0
|
||||
|
||||
val columnarBatchClz = classOf[ColumnarBatch].getName
|
||||
val batch = ctx.addMutableState(columnarBatchClz, "batch")
|
||||
|
||||
val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0
|
||||
val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
|
||||
val columnVectorClzs = vectorTypes.getOrElse(
|
||||
Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
|
||||
val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.errors._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
|
@ -154,7 +154,8 @@ case class ExpandExec(
|
|||
val value = ctx.freshName("value")
|
||||
val code = s"""
|
||||
|boolean $isNull = true;
|
||||
|${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)};
|
||||
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
|
||||
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
|
||||
""".stripMargin
|
||||
ExprCode(code, isNull, value)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
|
||||
|
@ -305,15 +305,15 @@ case class GenerateExec(
|
|||
nullable: Boolean,
|
||||
initialChecks: Seq[String]): ExprCode = {
|
||||
val value = ctx.freshName(name)
|
||||
val javaType = ctx.javaType(dt)
|
||||
val getter = ctx.getValue(source, dt, index)
|
||||
val javaType = CodeGenerator.javaType(dt)
|
||||
val getter = CodeGenerator.getValue(source, dt, index)
|
||||
val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)")
|
||||
if (checks.nonEmpty) {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val code =
|
||||
s"""
|
||||
|boolean $isNull = ${checks.mkString(" || ")};
|
||||
|$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter;
|
||||
|$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
|
||||
""".stripMargin
|
||||
ExprCode(code, isNull, value)
|
||||
} else {
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
|
||||
|
@ -133,7 +133,8 @@ case class SortExec(
|
|||
override def needStopCheck: Boolean = false
|
||||
|
||||
override protected def doProduce(ctx: CodegenContext): String = {
|
||||
val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
|
||||
val needToSort =
|
||||
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
|
||||
|
||||
// Initialize the class member variables. This includes the instance of the Sorter and
|
||||
// the iterator to return sorted rows.
|
||||
|
|
|
@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan {
|
|||
|
||||
variables.zipWithIndex.foreach { case (ev, i) =>
|
||||
val paramName = ctx.freshName(s"expr_$i")
|
||||
val paramType = ctx.javaType(attributes(i).dataType)
|
||||
val paramType = CodeGenerator.javaType(attributes(i).dataType)
|
||||
|
||||
arguments += ev.value
|
||||
parameters += s"$paramType $paramName"
|
||||
|
|
|
@ -178,7 +178,7 @@ case class HashAggregateExec(
|
|||
private var bufVars: Seq[ExprCode] = _
|
||||
|
||||
private def doProduceWithoutKeys(ctx: CodegenContext): String = {
|
||||
val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
|
||||
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
|
||||
// The generated function doesn't have input row in the code context.
|
||||
ctx.INPUT_ROW = null
|
||||
|
||||
|
@ -186,8 +186,8 @@ case class HashAggregateExec(
|
|||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||
val initExpr = functions.flatMap(f => f.initialValues)
|
||||
bufVars = initExpr.map { e =>
|
||||
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
|
||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
// The initial expression should not access any column
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars = s"""
|
||||
|
@ -532,7 +532,7 @@ case class HashAggregateExec(
|
|||
*/
|
||||
private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
|
||||
val isSupported =
|
||||
(groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
|
||||
(groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) ||
|
||||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
|
||||
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
|
||||
|
||||
|
@ -565,7 +565,7 @@ case class HashAggregateExec(
|
|||
}
|
||||
|
||||
private def doProduceWithKeys(ctx: CodegenContext): String = {
|
||||
val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg")
|
||||
val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
|
||||
if (sqlContext.conf.enableTwoLevelAggMap) {
|
||||
enableTwoLevelHashMap(ctx)
|
||||
} else {
|
||||
|
@ -757,7 +757,7 @@ case class HashAggregateExec(
|
|||
|
||||
val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
|
||||
incCounter) = if (testFallbackStartsAt.isDefined) {
|
||||
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter")
|
||||
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter")
|
||||
(s"$countTerm < ${testFallbackStartsAt.get._1}",
|
||||
s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;")
|
||||
} else {
|
||||
|
@ -832,7 +832,7 @@ case class HashAggregateExec(
|
|||
}
|
||||
val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) =>
|
||||
val dt = updateExpr(i).dataType
|
||||
ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
|
||||
CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable)
|
||||
}
|
||||
s"""
|
||||
|// common sub-expressions
|
||||
|
@ -855,7 +855,7 @@ case class HashAggregateExec(
|
|||
}
|
||||
val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
|
||||
val dt = updateExpr(i).dataType
|
||||
ctx.updateColumn(
|
||||
CodeGenerator.updateColumn(
|
||||
fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql.execution.aggregate
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -41,13 +41,13 @@ abstract class HashMapGenerator(
|
|||
val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
|
||||
val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
|
||||
val groupingKeySignature =
|
||||
groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
|
||||
groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ")
|
||||
val buffVars: Seq[ExprCode] = {
|
||||
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
|
||||
val initExpr = functions.flatMap(f => f.initialValues)
|
||||
initExpr.map { e =>
|
||||
val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue")
|
||||
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull")
|
||||
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
|
||||
val ev = e.genCode(ctx)
|
||||
val initVars =
|
||||
s"""
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
package org.apache.spark.sql.execution.aggregate
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext}
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -114,7 +114,7 @@ class RowBasedHashMapGenerator(
|
|||
|
||||
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
|
||||
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
|
||||
s"""(${ctx.genEqual(key.dataType, ctx.getValue("row",
|
||||
s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row",
|
||||
key.dataType, ordinal.toString()), key.name)})"""
|
||||
}.mkString(" && ")
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ class RowBasedHashMapGenerator(
|
|||
case t: DecimalType =>
|
||||
s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
|
||||
case t: DataType =>
|
||||
if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
|
||||
if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) {
|
||||
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
|
||||
}
|
||||
s"agg_rowWriter.write(${ordinal}, ${key.name})"
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate
|
|||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
|
||||
import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.vectorized.ColumnarBatch
|
||||
|
@ -127,7 +127,8 @@ class VectorizedHashMapGenerator(
|
|||
|
||||
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
|
||||
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
|
||||
val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]")
|
||||
val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType,
|
||||
"buckets[idx]")
|
||||
s"(${ctx.genEqual(key.dataType, value, key.name)})"
|
||||
}.mkString(" && ")
|
||||
}
|
||||
|
@ -182,14 +183,14 @@ class VectorizedHashMapGenerator(
|
|||
|
||||
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
|
||||
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
|
||||
ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
|
||||
CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name)
|
||||
}
|
||||
}
|
||||
|
||||
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
|
||||
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
|
||||
ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
|
||||
buffVars(ordinal), nullable = true)
|
||||
CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows",
|
||||
key.dataType, buffVars(ordinal), nullable = true)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
|
|||
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer}
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.types.LongType
|
||||
|
@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
|
|||
protected override def doProduce(ctx: CodegenContext): String = {
|
||||
val numOutput = metricTerm(ctx, "numOutputRows")
|
||||
|
||||
val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange")
|
||||
val number = ctx.addMutableState(ctx.JAVA_LONG, "number")
|
||||
val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
|
||||
val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
|
||||
|
||||
val value = ctx.freshName("value")
|
||||
val ev = ExprCode("", "false", value)
|
||||
|
@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
|
|||
// the metrics.
|
||||
|
||||
// Once number == batchEnd, it's time to progress to the next batch.
|
||||
val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd")
|
||||
val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
|
||||
|
||||
// How many values should still be generated by this range operator.
|
||||
val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo")
|
||||
val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo")
|
||||
|
||||
// How many values should be generated in the next batch.
|
||||
val nextBatchTodo = ctx.freshName("nextBatchTodo")
|
||||
|
|
|
@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
|
|||
val accessorName = ctx.addMutableState(accessorCls, "accessor")
|
||||
|
||||
val createCode = dt match {
|
||||
case t if ctx.isPrimitiveType(dt) =>
|
||||
case t if CodeGenerator.isPrimitiveType(dt) =>
|
||||
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
|
||||
case NullType | StringType | BinaryType =>
|
||||
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
|
||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
|
||||
|
@ -182,9 +182,10 @@ case class BroadcastHashJoinExec(
|
|||
// the variables are needed even there is no matched rows
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val value = ctx.freshName("value")
|
||||
val javaType = CodeGenerator.javaType(a.dataType)
|
||||
val code = s"""
|
||||
|boolean $isNull = true;
|
||||
|${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
|
||||
|$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|
||||
|if ($matched != null) {
|
||||
| ${ev.code}
|
||||
| $isNull = ${ev.isNull};
|
||||
|
|
|
@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
|
||||
|
@ -516,9 +516,9 @@ case class SortMergeJoinExec(
|
|||
ctx.INPUT_ROW = leftRow
|
||||
left.output.zipWithIndex.map { case (a, i) =>
|
||||
val value = ctx.freshName("value")
|
||||
val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
|
||||
val javaType = ctx.javaType(a.dataType)
|
||||
val defaultValue = ctx.defaultValue(a.dataType)
|
||||
val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString)
|
||||
val javaType = CodeGenerator.javaType(a.dataType)
|
||||
val defaultValue = CodeGenerator.defaultValue(a.dataType)
|
||||
if (a.nullable) {
|
||||
val isNull = ctx.freshName("isNull")
|
||||
val code =
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.serializer.Serializer
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering}
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
|
||||
import org.apache.spark.util.Utils
|
||||
|
@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
|
|||
}
|
||||
|
||||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
|
||||
val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
|
||||
val stopEarly =
|
||||
ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false
|
||||
|
||||
ctx.addNewFunction("stopEarly", s"""
|
||||
@Override
|
||||
|
@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
|
|||
return $stopEarly;
|
||||
}
|
||||
""", inlineToOuterClass = true)
|
||||
val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0
|
||||
val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0
|
||||
s"""
|
||||
| if ($countTerm < $limit) {
|
||||
| $countTerm += 1;
|
||||
|
|
Loading…
Reference in a new issue