[SPARK-20798] GenerateUnsafeProjection should check if a value is null before calling the getter

## What changes were proposed in this pull request?

GenerateUnsafeProjection.writeStructToBuffer() did not honor the assumption that the caller must make sure that a value is not null before using the getter. This could lead to various errors. This change fixes that behavior.

Example of code generated before:
```scala
/* 059 */         final UTF8String fieldName = value.getUTF8String(0);
/* 060 */         if (value.isNullAt(0)) {
/* 061 */           rowWriter1.setNullAt(0);
/* 062 */         } else {
/* 063 */           rowWriter1.write(0, fieldName);
/* 064 */         }
```

Example of code generated now:
```scala
/* 060 */         boolean isNull1 = value.isNullAt(0);
/* 061 */         UTF8String value1 = isNull1 ? null : value.getUTF8String(0);
/* 062 */         if (isNull1) {
/* 063 */           rowWriter1.setNullAt(0);
/* 064 */         } else {
/* 065 */           rowWriter1.write(0, value1);
/* 066 */         }
```

## How was this patch tested?

Adds GenerateUnsafeProjectionSuite.

Author: Ala Luszczak <ala@databricks.com>

Closes #18030 from ala/fix-generate-unsafe-projection.
This commit is contained in:
Ala Luszczak 2017-05-19 13:18:48 +02:00 committed by Herman van Hovell
parent 92580bd0ea
commit ce8edb8bf4
3 changed files with 78 additions and 4 deletions

View file

@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
fieldTypes: Seq[DataType],
bufferHolder: String): String = {
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
val fieldName = ctx.freshName("fieldName")
val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
val isNull = s"$input.isNullAt($i)"
ExprCode(code, isNull, fieldName)
val javaType = ctx.javaType(dt)
val isNullVar = ctx.freshName("isNull")
val valueVar = ctx.freshName("value")
val defaultValue = ctx.defaultValue(dt)
val readValue = ctx.getValue(input, dt, i.toString)
val code =
s"""
boolean $isNullVar = $input.isNullAt($i);
$javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
"""
ExprCode(code, isNullVar, valueVar)
}
s"""

View file

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
test("Test unsafe projection string access pattern") {
val dataType = (new StructType).add("a", StringType)
val exprs = BoundReference(0, dataType, nullable = true) :: Nil
val projection = GenerateUnsafeProjection.generate(exprs)
val result = projection.apply(InternalRow(AlwaysNull))
assert(!result.isNullAt(0))
assert(result.getStruct(0, 1).isNullAt(0))
}
}
object AlwaysNull extends InternalRow {
override def numFields: Int = 1
override def setNullAt(i: Int): Unit = {}
override def copy(): InternalRow = this
override def anyNull: Boolean = true
override def isNullAt(ordinal: Int): Boolean = true
override def update(i: Int, value: Any): Unit = notSupported
override def getBoolean(ordinal: Int): Boolean = notSupported
override def getByte(ordinal: Int): Byte = notSupported
override def getShort(ordinal: Int): Short = notSupported
override def getInt(ordinal: Int): Int = notSupported
override def getLong(ordinal: Int): Long = notSupported
override def getFloat(ordinal: Int): Float = notSupported
override def getDouble(ordinal: Int): Double = notSupported
override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported
override def getUTF8String(ordinal: Int): UTF8String = notSupported
override def getBinary(ordinal: Int): Array[Byte] = notSupported
override def getInterval(ordinal: Int): CalendarInterval = notSupported
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
override def getArray(ordinal: Int): ArrayData = notSupported
override def getMap(ordinal: Int): MapData = notSupported
override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported
private def notSupported: Nothing = throw new UnsupportedOperationException
}

View file

@ -198,21 +198,25 @@ public final class ColumnarBatch {
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getDecimal(rowId, precision, scale);
}
@Override
public UTF8String getUTF8String(int ordinal) {
if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getUTF8String(rowId);
}
@Override
public byte[] getBinary(int ordinal) {
if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getBinary(rowId);
}
@Override
public CalendarInterval getInterval(int ordinal) {
if (columns[ordinal].isNullAt(rowId)) return null;
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
return new CalendarInterval(months, microseconds);
@ -220,11 +224,13 @@ public final class ColumnarBatch {
@Override
public InternalRow getStruct(int ordinal, int numFields) {
if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getStruct(rowId);
}
@Override
public ArrayData getArray(int ordinal) {
if (columns[ordinal].isNullAt(rowId)) return null;
return columns[ordinal].getArray(rowId);
}