[SPARK-11894][SQL] fix isNull for GetInternalRowField

We should use `InternalRow.isNullAt` to check whether the field is null before calling `InternalRow.getXXX`.

Thanks to gatorsmile, who discovered this bug.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9904 from cloud-fan/null.
This commit is contained in:
Wenchen Fan 2015-11-23 10:13:59 -08:00 committed by Michael Armbrust
parent 94ce65dfcb
commit 1a5baaa651
2 changed files with 23 additions and 15 deletions

View file

@ -236,11 +236,6 @@ case class NewInstance(
}
if (propagateNull) {
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
s"${ev.isNull} = ${ev.value} == null;"
} else {
""
}
val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
s"""
@ -531,15 +526,15 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val row = child.gen(ctx)
nullSafeCodeGen(ctx, ev, eval => {
s"""
${row.code}
final boolean ${ev.isNull} = ${row.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)};
if ($eval.isNullAt($ordinal)) {
${ev.isNull} = true;
} else {
${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
}
"""
})
}
}

View file

@ -386,7 +386,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq((JavaData(1), 1L), (JavaData(2), 1L)))
}
ignore("Java encoder self join") {
test("Java encoder self join") {
implicit val kryoEncoder = Encoders.javaSerialization[JavaData]
val ds = Seq(JavaData(1), JavaData(2)).toDS()
assert(ds.joinWith(ds, lit(true)).collect().toSet ==
@ -396,6 +396,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(JavaData(2), JavaData(1)),
(JavaData(2), JavaData(2))))
}
// Regression test for SPARK-11894: joins two Datasets whose tuples contain a
// java.lang.Integer that is null in one row, and checks that the null fields
// are still null in every joined pair.
test("SPARK-11894: Incorrect results are returned when using null") {
// Typed null, so the first tuple element is a java.lang.Integer rather than Null.
val nullInt = null.asInstanceOf[java.lang.Integer]
val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
val ds2 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS()
// The join condition is the literal `true`, so each of the two rows of ds1 is
// paired with each of the two rows of ds2; all four pairs are listed below.
checkAnswer(
ds1.joinWith(ds2, lit(true)),
((nullInt, "1"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
}
}