diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 1e418540a2..523eed825f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -64,33 +64,29 @@ case class StaticInvoke( val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } + val callFunc = s"$objectName.$functionName($argString)" - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - ev.copy(code = s""" - ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = !$argsNonNull; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - - if ($argsNonNull) { - ${ev.value} = $objectName.$functionName($argString); - $objNullCheck - } - """) + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" } else { - ev.copy(code = s""" - ${argGen.map(_.code).mkString("\n")} - - $javaType ${ev.value} = $objectName.$functionName($argString); - final boolean ${ev.isNull} = ${ev.value} == null; - """) + s"boolean ${ev.isNull} = false;" } + + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. + val postNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val code = s""" + ${argGen.map(_.code).mkString("\n")} + $setIsNull + final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc; + $postNullCheck + """ + ev.copy(code = code) } } @@ -111,7 +107,8 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { + arguments: Seq[Expression] = Nil, + propagateNull: Boolean = true) extends Expression with NonSQLExpression { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -130,60 +127,53 @@ case class Invoke( case _ => None } - lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { - case (IntegerType, "java.lang.Object") => (s: String) => - s"((java.lang.Integer)$s).intValue()" - case (LongType, "java.lang.Object") => (s: String) => - s"((java.lang.Long)$s).longValue()" - case (FloatType, "java.lang.Object") => (s: String) => - s"((java.lang.Float)$s).floatValue()" - case (ShortType, "java.lang.Object") => (s: String) => - s"((java.lang.Short)$s).shortValue()" - case (ByteType, "java.lang.Object") => (s: String) => - s"((java.lang.Byte)$s).byteValue()" - case (DoubleType, "java.lang.Object") => (s: String) => - s"((java.lang.Double)$s).doubleValue()" - case (BooleanType, "java.lang.Object") => (s: String) => - s"((java.lang.Boolean)$s).booleanValue()" - case _ => identity[String] _ - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"boolean ${ev.isNull} = ${ev.value} == null;" + val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) { + s"${obj.value}.$functionName($argString)" } else { - ev.isNull = obj.isNull - "" + s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)" } - val value = unboxer(s"${obj.value}.$functionName($argString)") + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + } else { + s"boolean ${ev.isNull} = ${obj.isNull};" + } val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { - s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;" } else { s""" $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; try { - ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; + ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc; } catch (Exception e) { org.apache.spark.unsafe.Platform.throwException(e); } """ } - ev.copy(code = s""" + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. + val postNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} + $setIsNull $evaluate - $objNullCheck - """) + $postNullCheck + """ + ev.copy(code = code) } override def toString: String = s"$targetObject.$functionName" @@ -246,11 +236,13 @@ case class NewInstance( val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - val setup = - s""" - ${argGen.map(_.code).mkString("\n")} - ${outer.map(_.code).getOrElse("")} - """.stripMargin + var isNull = ev.isNull + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};" + } else { + isNull = "false" + "" + } val constructorCall = outer.map { gen => s"""${gen.value}.new ${cls.getSimpleName}($argString)""" @@ -258,27 +250,13 @@ case class NewInstance( s"new $className($argString)" } - if (propagateNull && argGen.nonEmpty) { - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - - ev.copy(code = s""" - $setup - - boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = $constructorCall; - ${ev.isNull} = false; - } - """) - } else { - ev.copy(code = s""" - $setup - - final $javaType ${ev.value} = $constructorCall; - final boolean ${ev.isNull} = false; - """) - } + val code = s""" + ${argGen.map(_.code).mkString("\n")} + ${outer.map(_.code).getOrElse("")} + $setIsNull + final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code, isNull = isNull) } override def toString: String = s"newInstance($cls)" @@ -306,13 +284,14 @@ case class UnwrapOption( val javaType = ctx.javaType(dataType) val inputObject = child.genCode(ctx) - ev.copy(code = s""" + val code = s""" ${inputObject.code} - boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); + final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = - ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); - """) + ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get(); + """ + ev.copy(code = code) } } @@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - ev.copy(code = s""" + val code = s""" ${inputObject.code} - boolean ${ev.isNull} = false; scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); - """) + """ + ev.copy(code = code, isNull = "false") } } @@ -474,7 +453,7 @@ case class MapObjects private( s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } - ev.copy(code = s""" + val code = s""" ${genInputData.code} boolean ${ev.isNull} = ${genInputData.value} == null; @@ -504,7 +483,8 @@ case class MapObjects private( ${ev.isNull} = false; ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); } - """) + """ + ev.copy(code = code) } } @@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) } """ } + val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) - ev.copy(code = s""" - boolean ${ev.isNull} = false; + + val code = s""" $values = new Object[${children.size}]; $childrenCode final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); - """) + """ + ev.copy(code = code, isNull = "false") } } @@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - ev.copy(code = s""" + val javaType = ctx.javaType(dataType) + val serialize = s"$serializer.serialize(${input.value}, null).array()" + + val code = s""" ${input.code} - final boolean ${ev.isNull} = ${input.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $serializer.serialize(${input.value}, null).array(); - } - """) + final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $serialize; + """ + ev.copy(code = code, isNull = input.isNull) } override def dataType: DataType = BinaryType @@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B serializer, s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") - // Code to serialize. + // Code to deserialize. val input = child.genCode(ctx) - ev.copy(code = s""" + val javaType = ctx.javaType(dataType) + val deserialize = + s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + + val code = s""" ${input.code} - final boolean ${ev.isNull} = ${input.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.javaType(dataType)}) - $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); - } - """) + final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + """ + ev.copy(code = code, isNull = input.isNull) } override def dataType: DataType = ObjectType(tag.runtimeClass) @@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp """ } - ev.isNull = instanceGen.isNull - ev.value = instanceGen.value - - ev.copy(code = s""" + val code = s""" ${instanceGen.code} if (!${instanceGen.isNull}) { ${initialize.mkString("\n")} } - """) + """ + ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) } } @@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + "please try to use scala.Option[_] or other nullable types " + "(e.g. java.lang.Integer instead of int/scala.Int)." - val idx = ctx.references.length - ctx.references += errMsg - ExprCode(code = s""" + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + + val code = s""" ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException((String) references[$idx]); - }""", isNull = "false", value = childGen.value) + throw new RuntimeException(this.$errMsgField); + } + """ + ev.copy(code = code, isNull = "false", value = childGen.value) } }