diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6a17a397b3..89ffbb0016 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -66,13 +66,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.copy(code = oev.code) } else { assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.") - val javaType = ctx.javaType(dataType) - val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = s""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + |$javaType ${ev.value} = ${ev.isNull} ? + | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 79b051670e..12330bfa55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -669,7 +669,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultIsNull = $inputIsNull; - ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + ${CodeGenerator.javaType(resultType)} $result = ${CodeGenerator.defaultValue(resultType)}; if (!$inputIsNull) { ${cast(input, result, resultIsNull)} } @@ -685,7 +685,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val funcName = ctx.freshName("elementToString") val elementToStringFunc = ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(et)} element) { + |private UTF8String $funcName(${CodeGenerator.javaType(et)} element) { | UTF8String elementStr = null; | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} | return elementStr; @@ -697,13 +697,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$buffer.append("["); |if ($array.numElements() > 0) { | if (!$array.isNullAt(0)) { - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")})); | } | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { | $buffer.append(","); | if (!$array.isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)})); | } | } |} @@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val dataToStringCode = castToStringCode(dataType, ctx) ctx.addNewFunction(funcName, s""" - |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} data) { | UTF8String dataStr = null; | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} | return dataStr; @@ -734,23 +734,26 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val keyToStringFunc = dataToStringFunc("keyToString", kt) val valueToStringFunc = dataToStringFunc("valueToString", vt) val loopIndex = ctx.freshName("loopIndex") + val getMapFirstKey = CodeGenerator.getValue(s"$map.keyArray()", kt, "0") + val getMapFirstValue = CodeGenerator.getValue(s"$map.valueArray()", vt, "0") + val getMapKeyArray = CodeGenerator.getValue(s"$map.keyArray()", kt, loopIndex) + val getMapValueArray = CodeGenerator.getValue(s"$map.valueArray()", vt, loopIndex) s""" |$buffer.append("["); |if ($map.numElements() > 0) { - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append($keyToStringFunc($getMapFirstKey)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt(0)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | $buffer.append($valueToStringFunc($getMapFirstValue)); | } | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { | $buffer.append(", "); - | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append($keyToStringFunc($getMapKeyArray)); | $buffer.append(" ->"); | if (!$map.valueArray().isNullAt($loopIndex)) { | $buffer.append(" "); - | $buffer.append($valueToStringFunc( - | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | $buffer.append($valueToStringFunc($getMapValueArray)); | } | } |} @@ -773,7 +776,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String | ${if (i != 0) s"""$buffer.append(" ");""" else ""} | | // Append $i field into the string buffer - | ${ctx.javaType(ft)} $field = ${ctx.getValue(row, ft, s"$i")}; + | ${CodeGenerator.javaType(ft)} $field = ${CodeGenerator.getValue(row, ft, s"$i")}; | UTF8String $fieldStr = null; | ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)} | $buffer.append($fieldStr); @@ -1202,8 +1205,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(fromType)} $fromElementPrim = - ${ctx.getValue(c, fromType, j)}; + ${CodeGenerator.javaType(fromType)} $fromElementPrim = + ${CodeGenerator.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { @@ -1259,20 +1262,20 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") val toFieldNull = ctx.freshName("tfn") - val fromType = ctx.javaType(from.fields(i).dataType) + val fromType = CodeGenerator.javaType(from.fields(i).dataType) s""" boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; + ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; + ${CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4568714933..ed90b18586 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,7 +119,7 @@ abstract class Expression extends TreeNode[Expression] { // TODO: support whole stage codegen too if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") { - val globalIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "globalIsNull") + val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull eval.isNull = globalIsNull s"$globalIsNull = $localIsNull;" @@ -127,7 +127,7 @@ abstract class Expression extends TreeNode[Expression] { "" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val newValue = ctx.freshName("value") val funcName = ctx.freshName(nodeName) @@ -411,14 +411,14 @@ abstract class UnaryExpression extends Expression { ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -510,7 +510,7 @@ abstract class BinaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { @@ -518,7 +518,7 @@ abstract class BinaryExpression extends Expression { boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } @@ -654,7 +654,7 @@ abstract class TernaryExpression extends Expression { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { ev.copy(code = s""" @@ -662,7 +662,7 @@ abstract class TernaryExpression extends Expression { ${leftGen.code} ${midGen.code} ${rightGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 11fb579dfa..4523079060 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,14 +65,14 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val partitionMaskTerm = "partitionMask" - ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_LONG, partitionMaskTerm) ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = "false") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 989c023056..e869258469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1018,11 +1018,12 @@ case class ScalaUDF( val udf = ctx.addReferenceObj("udf", function, s"scala.Function${children.length}") val getFuncResult = s"$udf.apply(${funcArgs.mkString(", ")})" val resultConverter = s"$convertersTerm[${children.length}]" + val boxedType = CodeGenerator.boxedType(dataType) val callFunc = s""" - |${ctx.boxedType(dataType)} $resultTerm = null; + |$boxedType $resultTerm = null; |try { - | $resultTerm = (${ctx.boxedType(dataType)})$resultConverter.apply($getFuncResult); + | $resultTerm = ($boxedType)$resultConverter.apply($getFuncResult); |} catch (Exception e) { | throw new org.apache.spark.SparkException($errorMsgTerm, e); |} @@ -1035,7 +1036,7 @@ case class ScalaUDF( |$callFunc | |boolean ${ev.isNull} = $resultTerm == null; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${ev.isNull}) { | ${ev.value} = $resultTerm; |} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index a160b9b275..cc6a769d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,8 +44,9 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = "partitionId" - ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm) + ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 9a9f579b37..6c4a3601c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -22,7 +22,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -165,7 +165,7 @@ case class PreciseTimestampConversion( val eval = child.genCode(ctx) ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; - |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } override def nullSafeEval(input: Any): Any = input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8bb14598a6..508bdd5050 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -49,8 +49,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression // codegen would fail to compile if we just write (-($c)) // for example, we could not write --9223372036854775808L in code s""" - ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval); - ${ev.value} = (${ctx.javaType(dt)})(-($originValue)); + ${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval); + ${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue)); """}) case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } @@ -107,7 +107,7 @@ case class Abs(child: Expression) case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(java.lang.Math.abs($c))") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))") } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) @@ -129,7 +129,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -167,7 +167,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => @@ -203,7 +203,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => defineCodeGen(ctx, ev, - (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => @@ -278,7 +278,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val divide = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -288,7 +288,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -299,7 +299,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -365,7 +365,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.value} == 0" } - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val remainder = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { @@ -375,7 +375,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -386,7 +386,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -454,13 +454,13 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { s"${eval2.value} == 0" } val remainder = ctx.freshName("remainder") - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val result = dataType match { case DecimalType.Fixed(_, _) => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value}.remainder(${eval2.value}); + $javaType $remainder = ${eval1.value}.remainder(${eval2.value}); if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { ${ev.value}=($remainder.$decimalAdd(${eval2.value})).remainder(${eval2.value}); } else { @@ -470,17 +470,16 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} $remainder = - (${ctx.javaType(dataType)})(${eval1.value} % ${eval2.value}); + $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value}); if ($remainder < 0) { - ${ev.value}=(${ctx.javaType(dataType)})(($remainder + ${eval2.value}) % ${eval2.value}); + ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value}); } else { ${ev.value}=$remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} $remainder = ${eval1.value} % ${eval2.value}; + $javaType $remainder = ${eval1.value} % ${eval2.value}; if ($remainder < 0) { ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value}; } else { @@ -493,7 +492,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if ($isZero) { ${ev.isNull} = true; } else { @@ -504,7 +503,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { @@ -602,7 +601,7 @@ case class Least(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -614,7 +613,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "least", @@ -629,7 +628,7 @@ case class Least(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } @@ -681,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) val evals = evalChildren.map(eval => s""" |${eval.code} @@ -693,7 +692,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin ) - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "greatest", @@ -708,7 +707,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 173481f06a..cc24e397cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -147,7 +147,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") + defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dataType)}) ~($c)") } protected override def nullSafeEval(input: Any): Any = not(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60a6f50472..793824b0b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,11 @@ import org.apache.spark.util.{ParentClassLoader, Utils} case class ExprCode(var code: String, var isNull: String, var value: String) object ExprCode { + def forNullValue(dataType: DataType): ExprCode = { + val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true) + ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + } + def forNonNullValue(value: String): ExprCode = { ExprCode(code = "", isNull = "false", value = value) } @@ -105,6 +110,8 @@ private[codegen] case class NewFunctionSpec( */ class CodegenContext { + import CodeGenerator._ + /** * Holding a list of objects that could be used passed into generated class. */ @@ -196,11 +203,11 @@ class CodegenContext { /** * Returns the reference of next available slot in current compacted array. The size of each - * compacted array is controlled by the constant `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * compacted array is controlled by the constant `MUTABLESTATEARRAY_SIZE_LIMIT`. * Once reaching the threshold, new compacted array is created. */ def getNextSlot(): String = { - if (currentIndex < CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT) { + if (currentIndex < MUTABLESTATEARRAY_SIZE_LIMIT) { val res = s"${arrayNames.last}[$currentIndex]" currentIndex += 1 res @@ -247,10 +254,10 @@ class CodegenContext { * are satisfied: * 1. forceInline is true * 2. its type is primitive type and the total number of the inlined mutable variables - * is less than `CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD` + * is less than `OUTER_CLASS_VARIABLES_THRESHOLD` * 3. its type is multi-dimensional array * When a variable is compacted into an array, the max size of the array for compaction - * is given by `CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT`. + * is given by `MUTABLESTATEARRAY_SIZE_LIMIT`. */ def addMutableState( javaType: String, @@ -261,7 +268,7 @@ class CodegenContext { // want to put a primitive type variable at outerClass for performance val canInlinePrimitive = isPrimitiveType(javaType) && - (inlinedMutableStates.length < CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) + (inlinedMutableStates.length < OUTER_CLASS_VARIABLES_THRESHOLD) if (forceInline || canInlinePrimitive || javaType.contains("[][]")) { val varName = if (useFreshName) freshName(variableName) else variableName val initCode = initFunc(varName) @@ -339,7 +346,7 @@ class CodegenContext { val length = if (index + 1 == numArrays) { mutableStateArrays.getCurrentIndex } else { - CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + MUTABLESTATEARRAY_SIZE_LIMIT } if (javaType.contains("[]")) { // initializer had an one-dimensional array variable @@ -468,7 +475,7 @@ class CodegenContext { inlineToOuterClass: Boolean): NewFunctionSpec = { val (className, classInstance) = if (inlineToOuterClass) { outerClassName -> "" - } else if (currClassSize > CodeGenerator.GENERATED_CLASS_SIZE_THRESHOLD) { + } else if (currClassSize > GENERATED_CLASS_SIZE_THRESHOLD) { val className = freshName("NestedClass") val classInstance = freshName("nestedClassInstance") @@ -537,14 +544,6 @@ class CodegenContext { extraClasses.append(code) } - final val JAVA_BOOLEAN = "boolean" - final val JAVA_BYTE = "byte" - final val JAVA_SHORT = "short" - final val JAVA_INT = "int" - final val JAVA_LONG = "long" - final val JAVA_FLOAT = "float" - final val JAVA_DOUBLE = "double" - /** * The map from a variable name to it's next ID. */ @@ -580,196 +579,6 @@ class CodegenContext { } } - /** - * Returns the specialized code to access a value from `inputRow` at `ordinal`. - */ - def getValue(input: String, dataType: DataType, ordinal: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" - case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" - case StringType => s"$input.getUTF8String($ordinal)" - case BinaryType => s"$input.getBinary($ordinal)" - case CalendarIntervalType => s"$input.getInterval($ordinal)" - case t: StructType => s"$input.getStruct($ordinal, ${t.size})" - case _: ArrayType => s"$input.getArray($ordinal)" - case _: MapType => s"$input.getMap($ordinal)" - case NullType => "null" - case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) - case _ => s"($jt)$input.get($ordinal, null)" - } - } - - /** - * Returns the code to update a column in Row for a given DataType. - */ - def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) - // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy - // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. - case StringType | _: StructType | _: ArrayType | _: MapType => - s"$row.update($ordinal, $value.copy())" - case _ => s"$row.update($ordinal, $value)" - } - } - - /** - * Update a column in MutableRow from ExprCode. - * - * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise - */ - def updateColumn( - row: String, - dataType: DataType, - ordinal: Int, - ev: ExprCode, - nullable: Boolean, - isVectorized: Boolean = false): String = { - if (nullable) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - if (!isVectorized && dataType.isInstanceOf[DecimalType]) { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - ${setColumn(row, dataType, ordinal, "null")}; - } - """ - } else { - s""" - if (!${ev.isNull}) { - ${setColumn(row, dataType, ordinal, ev.value)}; - } else { - $row.setNullAt($ordinal); - } - """ - } - } else { - s"""${setColumn(row, dataType, ordinal, ev.value)};""" - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType`. - */ - def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { - val jt = javaType(dataType) - dataType match { - case _ if isPrimitiveType(jt) => - s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" - case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" - case _ => - throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") - } - } - - /** - * Returns the specialized code to set a given value in a column vector for a given `DataType` - * that could potentially be nullable. - */ - def updateColumn( - vector: String, - rowId: String, - dataType: DataType, - ev: ExprCode, - nullable: Boolean): String = { - if (nullable) { - s""" - if (!${ev.isNull}) { - ${setValue(vector, rowId, dataType, ev.value)} - } else { - $vector.putNull($rowId); - } - """ - } else { - s"""${setValue(vector, rowId, dataType, ev.value)};""" - } - } - - /** - * Returns the specialized code to access a value from a column vector for a given `DataType`. - */ - def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { - if (dataType.isInstanceOf[StructType]) { - // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an - // `ordinal` parameter. - s"$vector.getStruct($rowId)" - } else { - getValue(vector, dataType, rowId) - } - } - - /** - * Returns the name used in accessor and setter for a Java primitive type. - */ - def primitiveTypeName(jt: String): String = jt match { - case JAVA_INT => "Int" - case _ => boxedType(jt) - } - - def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) - - /** - * Returns the Java type for a DataType. - */ - def javaType(dt: DataType): String = dt match { - case BooleanType => JAVA_BOOLEAN - case ByteType => JAVA_BYTE - case ShortType => JAVA_SHORT - case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG - case FloatType => JAVA_FLOAT - case DoubleType => JAVA_DOUBLE - case dt: DecimalType => "Decimal" - case BinaryType => "byte[]" - case StringType => "UTF8String" - case CalendarIntervalType => "CalendarInterval" - case _: StructType => "InternalRow" - case _: ArrayType => "ArrayData" - case _: MapType => "MapData" - case udt: UserDefinedType[_] => javaType(udt.sqlType) - case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" - case ObjectType(cls) => cls.getName - case _ => "Object" - } - - /** - * Returns the boxed type in Java. - */ - def boxedType(jt: String): String = jt match { - case JAVA_BOOLEAN => "Boolean" - case JAVA_BYTE => "Byte" - case JAVA_SHORT => "Short" - case JAVA_INT => "Integer" - case JAVA_LONG => "Long" - case JAVA_FLOAT => "Float" - case JAVA_DOUBLE => "Double" - case other => other - } - - def boxedType(dt: DataType): String = boxedType(javaType(dt)) - - /** - * Returns the representation of default value for a given Java Type. - */ - def defaultValue(jt: String): String = jt match { - case JAVA_BOOLEAN => "false" - case JAVA_BYTE => "(byte)-1" - case JAVA_SHORT => "(short)-1" - case JAVA_INT => "-1" - case JAVA_LONG => "-1L" - case JAVA_FLOAT => "-1.0f" - case JAVA_DOUBLE => "-1.0" - case _ => "null" - } - - def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) - /** * Generates code for equal expression in Java. */ @@ -812,6 +621,7 @@ class CodegenContext { val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") + val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -833,8 +643,8 @@ class CodegenContext { } else if ($isNullB) { return 1; } else { - ${javaType(elementType)} $elementA = ${getValue("a", elementType, "i")}; - ${javaType(elementType)} $elementB = ${getValue("b", elementType, "i")}; + $jt $elementA = ${getValue("a", elementType, "i")}; + $jt $elementB = ${getValue("b", elementType, "i")}; int comp = ${genComp(elementType, elementA, elementB)}; if (comp != 0) { return comp; @@ -906,19 +716,6 @@ class CodegenContext { } } - /** - * List of java data types that have special accessors and setters in [[InternalRow]]. - */ - val primitiveTypes = - Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) - - /** - * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. - */ - def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - - def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) - /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow @@ -1089,7 +886,7 @@ class CodegenContext { // for performance reasons, the functions are prepended, instead of appended, // thus here they are in reversed order val orderedFunctions = innerClassFunctions.reverse - if (orderedFunctions.size > CodeGenerator.MERGE_SPLIT_METHODS_THRESHOLD) { + if (orderedFunctions.size > MERGE_SPLIT_METHODS_THRESHOLD) { // Adding a new function to each inner class which contains the invocation of all the // ones which have been added to that inner class. For example, // private class NestedClass { @@ -1289,7 +1086,7 @@ class CodegenContext { * length less than a pre-defined constant. */ def isValidParamLength(paramLength: Int): Boolean = { - paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH + paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH } } @@ -1524,4 +1321,221 @@ object CodeGenerator extends Logging { result } }) + + /** + * Name of Java primitive data type + */ + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + + /** + * List of java primitive data types + */ + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) + + /** + * Returns true if a Java type is Java primitive primitive type + */ + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) + + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) + + /** + * Returns the specialized code to access a value from `inputRow` at `ordinal`. + */ + def getValue(input: String, dataType: DataType, ordinal: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})" + case StringType => s"$input.getUTF8String($ordinal)" + case BinaryType => s"$input.getBinary($ordinal)" + case CalendarIntervalType => s"$input.getInterval($ordinal)" + case t: StructType => s"$input.getStruct($ordinal, ${t.size})" + case _: ArrayType => s"$input.getArray($ordinal)" + case _: MapType => s"$input.getMap($ordinal)" + case NullType => "null" + case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal) + case _ => s"($jt)$input.get($ordinal, null)" + } + } + + /** + * Returns the code to update a column in Row for a given DataType. + */ + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" + case _ => s"$row.update($ordinal, $value)" + } + } + + /** + * Update a column in MutableRow from ExprCode. + * + * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean, + isVectorized: Boolean = false): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (!isVectorized && dataType.isInstanceOf[DecimalType]) { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | ${setColumn(row, dataType, ordinal, "null")}; + |} + """.stripMargin + } else { + s""" + |if (!${ev.isNull}) { + | ${setColumn(row, dataType, ordinal, ev.value)}; + |} else { + | $row.setNullAt($ordinal); + |} + """.stripMargin + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType`. + */ + def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = { + val jt = javaType(dataType) + dataType match { + case _ if isPrimitiveType(jt) => + s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" + case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" + case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _ => + throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") + } + } + + /** + * Returns the specialized code to set a given value in a column vector for a given `DataType` + * that could potentially be nullable. + */ + def updateColumn( + vector: String, + rowId: String, + dataType: DataType, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + s""" + |if (!${ev.isNull}) { + | ${setValue(vector, rowId, dataType, ev.value)} + |} else { + | $vector.putNull($rowId); + |} + """.stripMargin + } else { + s"""${setValue(vector, rowId, dataType, ev.value)};""" + } + } + + /** + * Returns the specialized code to access a value from a column vector for a given `DataType`. + */ + def getValueFromVector(vector: String, dataType: DataType, rowId: String): String = { + if (dataType.isInstanceOf[StructType]) { + // `ColumnVector.getStruct` is different from `InternalRow.getStruct`, it only takes an + // `ordinal` parameter. + s"$vector.getStruct($rowId)" + } else { + getValue(vector, dataType, rowId) + } + } + + /** + * Returns the name used in accessor and setter for a Java primitive type. + */ + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) + } + + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) + + /** + * Returns the Java type for a DataType. + */ + def javaType(dt: DataType): String = dt match { + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE + case _: DecimalType => "Decimal" + case BinaryType => "byte[]" + case StringType => "UTF8String" + case CalendarIntervalType => "CalendarInterval" + case _: StructType => "InternalRow" + case _: ArrayType => "ArrayData" + case _: MapType => "MapData" + case udt: UserDefinedType[_] => javaType(udt.sqlType) + case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]" + case ObjectType(cls) => cls.getName + case _ => "Object" + } + + /** + * Returns the boxed type in Java. + */ + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other + } + + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + + /** + * Returns the representation of default value for a given Java Type. + * @param jt the string name of the Java type + * @param typedNull if true, for null literals, return a typed (with a cast) version + */ + def defaultValue(jt: String, typedNull: Boolean): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" + case _ => if (typedNull) s"(($jt)null)" else "null" + } + + def defaultValue(dt: DataType, typedNull: Boolean = false): String = + defaultValue(javaType(dt), typedNull) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 0322d1dd6a..e12420bb5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -44,20 +44,21 @@ trait CodegenFallback extends Expression { } val objectTerm = ctx.freshName("obj") val placeHolder = ctx.registerComment(this.toString) + val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(this.dataType)}; if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { ev.copy(code = s""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); - ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index b53c0087e7..d35fd8ecb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -62,9 +62,9 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val projectionCodes: Seq[(String, String, String, Int)] = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) - val value = ctx.addMutableState(ctx.javaType(e.dataType), "value") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value") if (e.nullable) { - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "isNull") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull") (s""" |${ev.code} |$isNull = ${ev.isNull}; @@ -84,7 +84,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val updates = validExpr.zip(projectionCodes).map { case (e, (_, isNull, value, i)) => val ev = ExprCode("", isNull, value) - ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) + CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed..9a51be6ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -89,7 +89,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR s""" ${ctx.INPUT_ROW} = a; boolean $isNullA; - ${ctx.javaType(order.child.dataType)} $primitiveA; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveA; { ${eval.code} $isNullA = ${eval.isNull}; @@ -97,7 +97,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } ${ctx.INPUT_ROW} = b; boolean $isNullB; - ${ctx.javaType(order.child.dataType)} $primitiveB; + ${CodeGenerator.javaType(order.child.dataType)} $primitiveB; { ${eval.code} $isNullB = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 3dcbb518ba..f92f70ee71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -53,7 +53,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) + val converter = convertToSafe(ctx, CodeGenerator.getValue(tmpInput, dt, i.toString), dt) s""" if (!$tmpInput.isNullAt($i)) { ${converter.code} @@ -90,7 +90,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val arrayClass = classOf[GenericArrayData].getName val elementConverter = convertToSafe( - ctx, ctx.getValue(tmpInput, elementType, index), elementType) + ctx, CodeGenerator.getValue(tmpInput, elementType, index), elementType) val code = s""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); @@ -153,7 +153,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] mutableRow.setNullAt($i); } else { ${converter.code} - ${ctx.setColumn("mutableRow", e.dataType, i, converter.value)}; + ${CodeGenerator.setColumn("mutableRow", e.dataType, i, converter.value)}; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 36ffa8dcdd..22717f5954 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -52,7 +52,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) + ExprCode("", s"$tmpInput.isNullAt($i)", CodeGenerator.getValue(tmpInput, dt, i.toString)) } s""" @@ -195,16 +195,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case other => other } - val jt = ctx.javaType(et) + val jt = CodeGenerator.javaType(et) val elementOrOffsetSize = et match { case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8 - case _ if ctx.isPrimitiveType(jt) => et.defaultSize + case _ if CodeGenerator.isPrimitiveType(jt) => et.defaultSize case _ => 8 // we need 8 bytes to store offset and length } val tmpCursor = ctx.freshName("tmpCursor") - val element = ctx.getValue(tmpInput, et, index) + val element = CodeGenerator.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -235,7 +235,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$arrayWriter.write($index, $element);" } - val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val primitiveTypeName = + if (CodeGenerator.isPrimitiveType(jt)) CodeGenerator.primitiveTypeName(et) else "" s""" final ArrayData $tmpInput = $input; if ($tmpInput instanceof UnsafeArrayData) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4270b987d6..beb84694c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -54,7 +54,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType ev.copy(code = s""" boolean ${ev.isNull} = false; ${childGen.code} - ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : (${childGen.value}).numElements();""", isNull = "false") } } @@ -270,7 +270,7 @@ case class ArrayContains(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) + val getValue = CodeGenerator.getValue(arr, right.dataType, i) s""" for (int $i = 0; $i < $arr.numElements(); $i ++) { if ($arr.isNullAt($i)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 047b80ac52..85facdad43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -90,7 +90,7 @@ private [sql] object GenArrayData { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length - if (!ctx.isPrimitiveType(elementType)) { + if (!CodeGenerator.isPrimitiveType(elementType)) { val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName @@ -124,7 +124,7 @@ private [sql] object GenArrayData { ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - val primitiveValueTypeName = ctx.primitiveTypeName(elementType) + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { s"$arrayDataName.setNullAt($i);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 7e53ca3908..6cdad19168 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -129,12 +129,12 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; } """ } else { s""" - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + ${ev.value} = ${CodeGenerator.getValue(eval, dataType, ordinal.toString)}; """ } }) @@ -205,7 +205,7 @@ case class GetArrayStructFields( } else { final InternalRow $row = $eval.getStruct($j, $numFields); $nullSafeEval { - $values[$j] = ${ctx.getValue(row, field.dataType, ordinal.toString)}; + $values[$j] = ${CodeGenerator.getValue(row, field.dataType, ordinal.toString)}; } } } @@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) if ($index >= $eval1.numElements() || $index < 0$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(eval1, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)}; } """ }) @@ -327,6 +327,7 @@ case class GetMapValue(child: Expression, key: Expression) } else { "" } + val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); @@ -336,7 +337,7 @@ case class GetMapValue(child: Expression, key: Expression) int $index = 0; boolean $found = false; while ($index < $length && !$found) { - final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + final $keyJavaType $key = ${CodeGenerator.getValue(keys, keyType, index)}; if (${ctx.genEqual(keyType, key, eval2)}) { $found = true; } else { @@ -347,7 +348,7 @@ case class GetMapValue(child: Expression, key: Expression) if (!$found$nullCheck) { ${ev.isNull} = true; } else { - ${ev.value} = ${ctx.getValue(values, dataType, index)}; + ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index b444c3a7be..f4e9619bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -69,7 +69,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi s""" |${condEval.code} |boolean ${ev.isNull} = false; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |if (!${condEval.isNull} && ${condEval.value}) { | ${trueEval.code} | ${ev.isNull} = ${trueEval.isNull}; @@ -191,7 +191,7 @@ case class CaseWhen( // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") - ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value) + ev.value = ctx.addMutableState(CodeGenerator.javaType(dataType), ev.value) // these blocks are meant to be inside a // do { @@ -244,10 +244,10 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BYTE, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -264,7 +264,7 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; + |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 424871f204..1ae4e5a2f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -673,18 +673,19 @@ abstract class UnixTime } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) left.dataType match { case StringType if right.foldable => val df = classOf[DateFormat].getName if (formatter == null) { - ExprCode("", "true", ctx.defaultValue(dataType)) + ExprCode.forNullValue(dataType) } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; @@ -713,7 +714,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; }""") @@ -724,7 +725,7 @@ abstract class UnixTime ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") @@ -819,7 +820,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { try { ${ev.value} = UTF8String.fromString($formatterName.format( @@ -1344,18 +1345,19 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { : ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.$truncFuncStr; }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 055ebf6c0d..b702422ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -278,7 +278,7 @@ abstract class HashExpression[E] extends Expression { } } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", @@ -307,9 +307,10 @@ abstract class HashExpression[E] extends Expression { ctx: CodegenContext): String = { val element = ctx.freshName("element") + val jt = CodeGenerator.javaType(elementType) ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" - final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + final $jt $element = ${CodeGenerator.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} """ } @@ -407,7 +408,7 @@ abstract class HashExpression[E] extends Expression { val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - val hashResultType = ctx.javaType(dataType) + val hashResultType = CodeGenerator.javaType(dataType) ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", @@ -651,11 +652,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", - extraArguments = Seq(ctx.JAVA_INT -> ev.value), - returnType = ctx.JAVA_INT, + extraArguments = Seq(CodeGenerator.JAVA_INT -> ev.value), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$body |return ${ev.value}; """.stripMargin, @@ -664,8 +665,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = s""" - |${ctx.JAVA_INT} ${ev.value} = $seed; - |${ctx.JAVA_INT} $childHash = 0; + |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; + |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes """.stripMargin) } @@ -780,14 +781,14 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), - returnType = ctx.JAVA_INT, + arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" - |${ctx.JAVA_INT} $childResult = 0; + |${CodeGenerator.JAVA_INT} $childResult = 0; |$body |return $result; """.stripMargin, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 7a8edabed1..07785e7448 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getInputFilePath();", isNull = "false") } } @@ -65,7 +65,7 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getStartOffset();", isNull = "false") } } @@ -88,7 +88,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + s"$className.getLength();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index c1e65e34c2..7395609a04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -277,13 +277,9 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) if (value == null) { - val defaultValueLiteral = ctx.defaultValue(javaType) match { - case "null" => s"(($javaType)null)" - case lit => lit - } - ExprCode(code = "", isNull = "true", value = defaultValueLiteral) + ExprCode.forNullValue(dataType) } else { dataType match { case BooleanType | IntegerType | DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index d8dc0862f1..2c2cf3d2e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1128,15 +1128,16 @@ abstract class RoundBase(child: Expression, scale: Expression, }""" } + val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode }""") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 470d5da041..b35fa72e95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + ev.isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -87,14 +87,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } - val resultType = ctx.javaType(dataType) + val resultType = CodeGenerator.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", returnType = resultType, makeSplitFunction = func => s""" - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +113,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |${ev.isNull} = true; - |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { | $codes |} while (false); @@ -234,7 +234,7 @@ case class IsNaN(child: Expression) extends UnaryExpression case DoubleType | FloatType => ev.copy(code = s""" ${eval.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = "false") } } @@ -281,7 +281,7 @@ case class NaNvl(left: Expression, right: Expression) ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (${leftGen.isNull}) { ${ev.isNull} = true; } else { @@ -416,8 +416,8 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "atLeastNNonNulls", - extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, + extraArguments = (CodeGenerator.JAVA_INT, nonnull) :: Nil, + returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" |do { @@ -436,11 +436,11 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate ev.copy(code = s""" - |${ctx.JAVA_INT} $nonnull = 0; + |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes |} while (false); - |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 64da9bb9cd..80618af1e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -62,13 +62,13 @@ trait InvokeLike extends Expression with NonSQLExpression { def prepareArguments(ctx: CodegenContext): (String, String, String) = { val resultIsNull = if (needNullCheck) { - val resultIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "resultIsNull") + val resultIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "resultIsNull") resultIsNull } else { "false" } val argValues = arguments.map { e => - val argValue = ctx.addMutableState(ctx.javaType(e.dataType), "argValue") + val argValue = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "argValue") argValue } @@ -137,7 +137,7 @@ case class StaticInvoke( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -151,7 +151,7 @@ case class StaticInvoke( } val evaluate = if (returnNullable) { - if (ctx.defaultValue(dataType) == "null") { + if (CodeGenerator.defaultValue(dataType) == "null") { s""" ${ev.value} = $callFunc; ${ev.isNull} = ${ev.value} == null; @@ -159,7 +159,7 @@ case class StaticInvoke( } else { val boxedResult = ctx.freshName("boxedResult") s""" - ${ctx.boxedType(dataType)} $boxedResult = $callFunc; + ${CodeGenerator.boxedType(dataType)} $boxedResult = $callFunc; ${ev.isNull} = $boxedResult == null; if (!${ev.isNull}) { ${ev.value} = $boxedResult; @@ -173,7 +173,7 @@ case class StaticInvoke( val code = s""" $argCode $prepareIsNull - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!$resultIsNull) { $evaluate } @@ -228,7 +228,7 @@ case class Invoke( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val obj = targetObject.genCode(ctx) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -255,11 +255,11 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still // set correctly. val assignResult = if (!returnNullable) { - s"${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;" + s"${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult;" } else { s""" if ($funcResult != null) { - ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; + ${ev.value} = (${CodeGenerator.boxedType(javaType)}) $funcResult; } else { ${ev.isNull} = true; } @@ -275,7 +275,7 @@ case class Invoke( val code = s""" ${obj.code} boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { $argCode ${ev.isNull} = $resultIsNull; @@ -341,7 +341,7 @@ case class NewInstance( throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val (argCode, argString, resultIsNull) = prepareArguments(ctx) @@ -358,7 +358,8 @@ case class NewInstance( val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + final $javaType ${ev.value} = ${ev.isNull} ? + ${CodeGenerator.defaultValue(dataType)} : $constructorCall; """ ev.copy(code = code) } @@ -385,15 +386,15 @@ case class UnwrapOption( throw new UnsupportedOperationException("Only code-generated evaluation is supported") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) val code = s""" ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : + (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) } @@ -546,7 +547,7 @@ case class MapObjects private( ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val elementJavaType = ctx.javaType(loopVarDataType) + val elementJavaType = CodeGenerator.javaType(loopVarDataType) ctx.addMutableState(elementJavaType, loopValue, forceInline = true, useFreshName = false) val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) @@ -554,7 +555,7 @@ case class MapObjects private( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(lambdaFunction.dataType) + val convertedType = CodeGenerator.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -621,7 +622,7 @@ case class MapObjects private( ( s"${genInputData.value}.numElements()", "", - ctx.getValue(genInputData.value, et, loopIndex) + CodeGenerator.getValue(genInputData.value, et, loopIndex) ) case ObjectType(cls) if cls == classOf[Object] => val it = ctx.freshName("it") @@ -643,7 +644,8 @@ case class MapObjects private( } val loopNullCheck = if (loopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, loopIsNull, forceInline = true, useFreshName = false) inputDataType match { case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" case _ => s"$loopIsNull = $loopValue == null;" @@ -695,7 +697,7 @@ case class MapObjects private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { $determineCollectionType @@ -806,10 +808,10 @@ case class CatalystToExternalMap private( } val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType] - val keyElementJavaType = ctx.javaType(mapType.keyType) + val keyElementJavaType = CodeGenerator.javaType(mapType.keyType) ctx.addMutableState(keyElementJavaType, keyLoopValue, forceInline = true, useFreshName = false) val genKeyFunction = keyLambdaFunction.genCode(ctx) - val valueElementJavaType = ctx.javaType(mapType.valueType) + val valueElementJavaType = CodeGenerator.javaType(mapType.valueType) ctx.addMutableState(valueElementJavaType, valueLoopValue, forceInline = true, useFreshName = false) val genValueFunction = valueLambdaFunction.genCode(ctx) @@ -825,10 +827,11 @@ case class CatalystToExternalMap private( val valueArray = ctx.freshName("valueArray") val getKeyArray = s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();" - val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) + val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex) val getValueArray = s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();" - val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex) + val getValueLoopVar = CodeGenerator.getValue( + valueArray, inputDataType(mapType.valueType), loopIndex) // Make a copy of the data if it's unsafe-backed def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = @@ -844,7 +847,7 @@ case class CatalystToExternalMap private( val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction) val valueLoopNullCheck = if (valueLoopIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, valueLoopIsNull, forceInline = true, useFreshName = false) s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);" } else { @@ -873,7 +876,7 @@ case class CatalystToExternalMap private( val code = s""" ${genInputData.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { int $dataLength = $getLength; @@ -993,8 +996,8 @@ case class ExternalMapToCatalyst private( val entry = ctx.freshName("entry") val entries = ctx.freshName("entries") - val keyElementJavaType = ctx.javaType(keyType) - val valueElementJavaType = ctx.javaType(valueType) + val keyElementJavaType = CodeGenerator.javaType(keyType) + val valueElementJavaType = CodeGenerator.javaType(valueType) ctx.addMutableState(keyElementJavaType, key, forceInline = true, useFreshName = false) ctx.addMutableState(valueElementJavaType, value, forceInline = true, useFreshName = false) @@ -1009,8 +1012,8 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry.getKey(); - $value = (${ctx.boxedType(valueType)}) $entry.getValue(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry.getKey(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry.getValue(); """ defineEntries -> defineKeyValue @@ -1024,22 +1027,24 @@ case class ExternalMapToCatalyst private( val defineKeyValue = s""" final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next(); - $key = (${ctx.boxedType(keyType)}) $entry._1(); - $value = (${ctx.boxedType(valueType)}) $entry._2(); + $key = (${CodeGenerator.boxedType(keyType)}) $entry._1(); + $value = (${CodeGenerator.boxedType(valueType)}) $entry._2(); """ defineEntries -> defineKeyValue } val keyNullCheck = if (keyIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, keyIsNull, forceInline = true, useFreshName = false) s"$keyIsNull = $key == null;" } else { "" } val valueNullCheck = if (valueIsNull != "false") { - ctx.addMutableState(ctx.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) + ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, valueIsNull, forceInline = true, useFreshName = false) s"$valueIsNull = $value == null;" } else { "" @@ -1047,12 +1052,12 @@ case class ExternalMapToCatalyst private( val arrayCls = classOf[GenericArrayData].getName val mapCls = classOf[ArrayBasedMapData].getName - val convertedKeyType = ctx.boxedType(keyConverter.dataType) - val convertedValueType = ctx.boxedType(valueConverter.dataType) + val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) + val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) val code = s""" ${inputMap.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); final Object[] $convertedKeys = new Object[$length]; @@ -1174,12 +1179,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1223,13 +1229,14 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to deserialize. val input = child.genCode(ctx) - val javaType = ctx.javaType(dataType) + val javaType = CodeGenerator.javaType(dataType) val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" val code = s""" ${input.code} - final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ ev.copy(code = code, isNull = input.isNull) } @@ -1254,7 +1261,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp val instanceGen = beanInstance.genCode(ctx) val javaBeanInstance = ctx.freshName("javaBean") - val beanInstanceJavaType = ctx.javaType(beanInstance.dataType) + val beanInstanceJavaType = CodeGenerator.javaType(beanInstance.dataType) val initialize = setters.map { case (setterMethod, fieldValue) => @@ -1405,15 +1412,15 @@ case class ValidateExternalType(child: Expression, expected: DataType) case _: ArrayType => s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" case _ => - s"$obj instanceof ${ctx.boxedType(dataType)}" + s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } val code = s""" ${input.code} - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { if ($typeCheck) { - ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $obj; } else { throw new RuntimeException($obj.getClass().getName() + $errMsgField); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a6d41ea7d0..4b85d9adbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,7 +21,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -235,7 +235,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaDataType = ctx.javaType(value.dataType) + val javaDataType = CodeGenerator.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: @@ -263,8 +263,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, - returnType = ctx.JAVA_BYTE, + extraArguments = (javaDataType, valueArg) :: (CodeGenerator.JAVA_BYTE, tmpResult) :: Nil, + returnType = CodeGenerator.JAVA_BYTE, makeSplitFunction = body => s""" |do { @@ -348,8 +348,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ev.copy(code = s""" |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; |if (!${ev.isNull}) { | ${ev.value} = $setTerm.contains(${childGen.value}); | $setIsNull @@ -505,7 +505,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (ctx.isPrimitiveType(left.dataType) + if (CodeGenerator.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType && left.dataType != DoubleType) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bc936fcbf..6c9937dacc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +82,8 @@ case class Rand(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", + isNull = "false") } } @@ -116,7 +117,8 @@ case class Randn(child: Expression) extends RDG { ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") ev.copy(code = s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") + final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", + isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f3e8f6de58..ad0c0791d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -126,7 +126,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } @@ -134,7 +134,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { @@ -201,7 +201,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } @@ -209,7 +209,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } else { ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) } } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d7612e30b4..22fbb8998e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -102,11 +102,11 @@ case class Concat(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil) + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) ev.copy(s""" $initCode $codes - ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); boolean ${ev.isNull} = ${ev.value} == null; """) } @@ -196,7 +196,7 @@ case class ConcatWs(children: Seq[Expression]) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") - val idxInVararg = ctx.freshName("idxInVararg") + val idxVararg = ctx.freshName("idxInVararg") val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => @@ -206,7 +206,7 @@ case class ConcatWs(children: Seq[Expression]) if (eval.isNull == "true") { "" } else { - s"$array[$idxInVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" + s"$array[$idxVararg ++] = ${eval.isNull} ? (UTF8String) null : ${eval.value};" }) case _: ArrayType => val size = ctx.freshName("n") @@ -222,7 +222,7 @@ case class ConcatWs(children: Seq[Expression]) if (!${eval.isNull}) { final int $size = ${eval.value}.numElements(); for (int j = 0; j < $size; j ++) { - $array[$idxInVararg ++] = ${ctx.getValue(eval.value, StringType, "j")}; + $array[$idxVararg ++] = ${CodeGenerator.getValue(eval.value, StringType, "j")}; } } """) @@ -247,20 +247,20 @@ case class ConcatWs(children: Seq[Expression]) val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" |$body - |return $idxInVararg; + |return $idxVararg; """.stripMargin, - foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( s""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; - int $idxInVararg = 0; + int $idxVararg = 0; $varargCounts UTF8String[] $array = new UTF8String[$varargNum]; $varargBuilds @@ -333,7 +333,7 @@ case class Elt(children: Seq[Expression]) extends Expression { val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") + val inputVal = ctx.addMutableState(CodeGenerator.javaType(dataType), "inputVal") val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" @@ -350,10 +350,10 @@ case class Elt(children: Seq[Expression]) extends Expression { expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, - returnType = ctx.JAVA_BOOLEAN, + returnType = CodeGenerator.JAVA_BOOLEAN, makeSplitFunction = body => s""" - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |do { | $body |} while (false); @@ -372,12 +372,12 @@ case class Elt(children: Seq[Expression]) extends Expression { s""" |${index.code} |final int $indexVal = ${index.value}; - |${ctx.JAVA_BOOLEAN} $indexMatched = false; + |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; |$inputVal = null; |do { | $codes |} while (false); - |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; + |final ${CodeGenerator.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } @@ -1410,10 +1410,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val numArgLists = argListGen.length val argListCode = argListGen.zipWithIndex.map { case(v, index) => val value = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + if (CodeGenerator.boxedType(v._1) != CodeGenerator.javaType(v._1)) { // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.value})" + s"(${v._2.isNull}) ? (${CodeGenerator.boxedType(v._1)}) null : " + + s"new ${CodeGenerator.boxedType(v._1)}(${v._2.value})" } else { s"(${v._2.isNull}) ? null : ${v._2.value}" } @@ -1434,7 +1434,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { $stringBuffer $sb = new $stringBuffer(); $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); @@ -2110,7 +2110,8 @@ case class FormatNumber(x: Expression, d: Expression) val usLocale = "US" val i = ctx.freshName("i") val dFormat = ctx.freshName("dFormat") - val lastDValue = ctx.addMutableState(ctx.JAVA_INT, "lastDValue", v => s"$v = -100;") + val lastDValue = + ctx.addMutableState(CodeGenerator.JAVA_INT, "lastDValue", v => s"$v = -100;") val pattern = ctx.addMutableState(sb, "pattern", v => s"$v = new $sb();") val numberFormat = ctx.addMutableState(df, "numberFormat", v => s"""$v = new $df("", new $dfs($l.$usLocale));""") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 676ba3956d..1e48c7b8df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -405,12 +405,12 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-18016: define mutable states by using an array") { val ctx1 = new CodegenContext for (i <- 1 to CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) { - ctx1.addMutableState(ctx1.JAVA_INT, "i", v => s"$v = $i;") + ctx1.addMutableState(CodeGenerator.JAVA_INT, "i", v => s"$v = $i;") } assert(ctx1.inlinedMutableStates.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD) // When the number of primitive type mutable states is over the threshold, others are // allocated into an array - assert(ctx1.arrayCompactedMutableStates.get(ctx1.JAVA_INT).get.arrayNames.size == 1) + assert(ctx1.arrayCompactedMutableStates.get(CodeGenerator.JAVA_INT).get.arrayNames.size == 1) assert(ctx1.mutableStateInitCode.size == CodeGenerator.OUTER_CLASS_VARIABLES_THRESHOLD + 10) val ctx2 = new CodegenContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 04f2619ed7..392906a022 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -49,15 +49,15 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { - val javaType = ctx.javaType(dataType) - val value = ctx.getValueFromVector(columnVar, dataType, ordinal) + val javaType = CodeGenerator.javaType(dataType) + val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal) val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean $isNullVar = $columnVar.isNullAt($ordinal); - $javaType $valueVar = $isNullVar ? ${ctx.defaultValue(dataType)} : ($value); + $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { s"$javaType $valueVar = $value;" @@ -85,12 +85,13 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { // metrics val numOutputRows = metricTerm(ctx, "numOutputRows") val scanTimeMetric = metricTerm(ctx, "scanTime") - val scanTimeTotalNs = ctx.addMutableState(ctx.JAVA_LONG, "scanTime") // init as scanTime = 0 + val scanTimeTotalNs = + ctx.addMutableState(CodeGenerator.JAVA_LONG, "scanTime") // init as scanTime = 0 val columnarBatchClz = classOf[ColumnarBatch].getName val batch = ctx.addMutableState(columnarBatchClz, "batch") - val idx = ctx.addMutableState(ctx.JAVA_INT, "batchIdx") // init as batchIdx = 0 + val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0 val columnVectorClzs = vectorTypes.getOrElse( Seq.fill(output.indices.size)(classOf[ColumnVector].getName)) val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index a7bd5ebf93..12ae1ea4a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -154,7 +154,8 @@ case class ExpandExec( val value = ctx.freshName("value") val code = s""" |boolean $isNull = true; - |${ctx.javaType(firstExpr.dataType)} $value = ${ctx.defaultValue(firstExpr.dataType)}; + |${CodeGenerator.javaType(firstExpr.dataType)} $value = + | ${CodeGenerator.defaultValue(firstExpr.dataType)}; """.stripMargin ExprCode(code, isNull, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 0c2c4a1a91..384f0398a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -305,15 +305,15 @@ case class GenerateExec( nullable: Boolean, initialChecks: Seq[String]): ExprCode = { val value = ctx.freshName(name) - val javaType = ctx.javaType(dt) - val getter = ctx.getValue(source, dt, index) + val javaType = CodeGenerator.javaType(dt) + val getter = CodeGenerator.getValue(source, dt, index) val checks = initialChecks ++ optionalCode(nullable, s"$source.isNullAt($index)") if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = s""" |boolean $isNull = ${checks.mkString(" || ")}; - |$javaType $value = $isNull ? ${ctx.defaultValue(dt)} : $getter; + |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, isNull, value) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index ac1c34d41c..0dc16ba5ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -133,7 +133,8 @@ case class SortExec( override def needStopCheck: Boolean = false override protected def doProduce(ctx: CodegenContext): String = { - val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") + val needToSort = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;") // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index deb0a044c2..f89e3fb0e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -234,7 +234,7 @@ trait CodegenSupport extends SparkPlan { variables.zipWithIndex.foreach { case (ev, i) => val paramName = ctx.freshName(s"expr_$i") - val paramType = ctx.javaType(attributes(i).dataType) + val paramType = CodeGenerator.javaType(attributes(i).dataType) arguments += ev.value parameters += s"$paramType $paramName" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index ce3c68810f..1926e9373b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -178,7 +178,7 @@ case class HashAggregateExec( private var bufVars: Seq[ExprCode] = _ private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") // The generated function doesn't have input row in the code context. ctx.INPUT_ROW = null @@ -186,8 +186,8 @@ case class HashAggregateExec( val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) bufVars = initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) val initVars = s""" @@ -532,7 +532,7 @@ case class HashAggregateExec( */ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = { val isSupported = - (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) || + (groupingKeySchema ++ bufferSchema).forall(f => CodeGenerator.isPrimitiveType(f.dataType) || f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge) @@ -565,7 +565,7 @@ case class HashAggregateExec( } private def doProduceWithKeys(ctx: CodegenContext): String = { - val initAgg = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initAgg") + val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else { @@ -757,7 +757,7 @@ case class HashAggregateExec( val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "fallbackCounter") + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") (s"$countTerm < ${testFallbackStartsAt.get._1}", s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { @@ -832,7 +832,7 @@ case class HashAggregateExec( } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" |// common sub-expressions @@ -855,7 +855,7 @@ case class HashAggregateExec( } val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType - ctx.updateColumn( + CodeGenerator.updateColumn( fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 1c613b19c4..6b60b414ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.types._ /** @@ -41,13 +41,13 @@ abstract class HashMapGenerator( val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key"))) val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value"))) val groupingKeySignature = - groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ") + groupingKeys.map(key => s"${CodeGenerator.javaType(key.dataType)} ${key.name}").mkString(", ") val buffVars: Seq[ExprCode] = { val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val initExpr = functions.flatMap(f => f.initialValues) initExpr.map { e => - val isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "bufIsNull") - val value = ctx.addMutableState(ctx.javaType(e.dataType), "bufValue") + val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "bufIsNull") + val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index fd25707dd4..8617be88f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.types._ /** @@ -114,7 +114,7 @@ class RowBasedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("row", + s"""(${ctx.genEqual(key.dataType, CodeGenerator.getValue("row", key.dataType, ordinal.toString()), key.name)})""" }.mkString(" && ") } @@ -147,7 +147,7 @@ class RowBasedHashMapGenerator( case t: DecimalType => s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})" case t: DataType => - if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) { + if (!t.isInstanceOf[StringType] && !CodeGenerator.isPrimitiveType(t)) { throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t") } s"agg_rowWriter.write(${ordinal}, ${key.name})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 633eeac180..7b3580cecc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -127,7 +127,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - val value = ctx.getValueFromVector(s"vectors[$ordinal]", key.dataType, "buckets[idx]") + val value = CodeGenerator.getValueFromVector(s"vectors[$ordinal]", key.dataType, + "buckets[idx]") s"(${ctx.genEqual(key.dataType, value, key.name)})" }.mkString(" && ") } @@ -182,14 +183,14 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) + CodeGenerator.setValue(s"vectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, - buffVars(ordinal), nullable = true) + CodeGenerator.updateColumn(s"vectors[${groupingKeys.length + ordinal}]", "numRows", + key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index a15a8d11aa..4707022f74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType @@ -364,8 +364,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) protected override def doProduce(ctx: CodegenContext): String = { val numOutput = metricTerm(ctx, "numOutputRows") - val initTerm = ctx.addMutableState(ctx.JAVA_BOOLEAN, "initRange") - val number = ctx.addMutableState(ctx.JAVA_LONG, "number") + val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange") + val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number") val value = ctx.freshName("value") val ev = ExprCode("", "false", value) @@ -385,10 +385,10 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) // the metrics. // Once number == batchEnd, it's time to progress to the next batch. - val batchEnd = ctx.addMutableState(ctx.JAVA_LONG, "batchEnd") + val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd") // How many values should still be generated by this range operator. - val numElementsTodo = ctx.addMutableState(ctx.JAVA_LONG, "numElementsTodo") + val numElementsTodo = ctx.addMutableState(CodeGenerator.JAVA_LONG, "numElementsTodo") // How many values should be generated in the next batch. val nextBatchTodo = ctx.freshName("nextBatchTodo") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index 4f28eeb725..3b5655ba05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -91,7 +91,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val accessorName = ctx.addMutableState(accessorCls, "accessor") val createCode = dt match { - case t if ctx.isPrimitiveType(dt) => + case t if CodeGenerator.isPrimitiveType(dt) => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" case NullType | StringType | BinaryType => s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 1918fcc548..487d6a2383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -22,7 +22,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -182,9 +182,10 @@ case class BroadcastHashJoinExec( // the variables are needed even there is no matched rows val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) val code = s""" |boolean $isNull = true; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { | ${ev.code} | $isNull = ${ev.isNull}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb0..5a511b30e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, @@ -516,9 +516,9 @@ case class SortMergeJoinExec( ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") - val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - val javaType = ctx.javaType(a.dataType) - val defaultValue = ctx.defaultValue(a.dataType) + val valueCode = CodeGenerator.getValue(leftRow, a.dataType, i.toString) + val javaType = CodeGenerator.javaType(a.dataType) + val defaultValue = CodeGenerator.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cccee63bc0..66bcda8913 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils @@ -71,7 +71,8 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val stopEarly = ctx.addMutableState(ctx.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false + val stopEarly = + ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "stopEarly") // init as stopEarly = false ctx.addNewFunction("stopEarly", s""" @Override @@ -79,7 +80,7 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { return $stopEarly; } """, inlineToOuterClass = true) - val countTerm = ctx.addMutableState(ctx.JAVA_INT, "count") // init as count = 0 + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "count") // init as count = 0 s""" | if ($countTerm < $limit) { | $countTerm += 1;