[SPARK-20798] GenerateUnsafeProjection should check if a value is null before calling the getter
## What changes were proposed in this pull request? GenerateUnsafeProjection.writeStructToBuffer() did not honor the assumption that the caller must make sure that a value is not null before using the getter. This could lead to various errors. This change fixes that behavior. Example of code generated before: ```java /* 059 */ final UTF8String fieldName = value.getUTF8String(0); /* 060 */ if (value.isNullAt(0)) { /* 061 */ rowWriter1.setNullAt(0); /* 062 */ } else { /* 063 */ rowWriter1.write(0, fieldName); /* 064 */ } ``` Example of code generated now: ```java /* 060 */ boolean isNull1 = value.isNullAt(0); /* 061 */ UTF8String value1 = isNull1 ? null : value.getUTF8String(0); /* 062 */ if (isNull1) { /* 063 */ rowWriter1.setNullAt(0); /* 064 */ } else { /* 065 */ rowWriter1.write(0, value1); /* 066 */ } ``` ## How was this patch tested? Adds GenerateUnsafeProjectionSuite. Author: Ala Luszczak <ala@databricks.com> Closes #18030 from ala/fix-generate-unsafe-projection.
This commit is contained in:
parent
92580bd0ea
commit
ce8edb8bf4
|
@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
|
|||
fieldTypes: Seq[DataType],
|
||||
bufferHolder: String): String = {
|
||||
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
|
||||
val fieldName = ctx.freshName("fieldName")
|
||||
val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
|
||||
val isNull = s"$input.isNullAt($i)"
|
||||
ExprCode(code, isNull, fieldName)
|
||||
val javaType = ctx.javaType(dt)
|
||||
val isNullVar = ctx.freshName("isNull")
|
||||
val valueVar = ctx.freshName("value")
|
||||
val defaultValue = ctx.defaultValue(dt)
|
||||
val readValue = ctx.getValue(input, dt, i.toString)
|
||||
val code =
|
||||
s"""
|
||||
boolean $isNullVar = $input.isNullAt($i);
|
||||
$javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
|
||||
"""
|
||||
ExprCode(code, isNullVar, valueVar)
|
||||
}
|
||||
|
||||
s"""
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions.codegen
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.BoundReference
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
|
||||
import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType}
|
||||
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
|
||||
|
||||
/**
 * Regression test for SPARK-20798: generated unsafe projections must check
 * `isNullAt` on the input row before invoking the typed getter for a field.
 */
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
  test("Test unsafe projection string access pattern") {
    // A single nullable struct column containing one string field.
    val schema = (new StructType).add("a", StringType)
    val boundRefs = List(BoundReference(0, schema, nullable = true))
    val unsafeProj = GenerateUnsafeProjection.generate(boundRefs)
    // AlwaysNull reports every field as null but throws from every typed getter,
    // so this apply() only succeeds if the generated code checks isNullAt first.
    val projected = unsafeProj.apply(InternalRow(AlwaysNull))
    // The struct column itself is non-null; its inner string field is null.
    assert(!projected.isNullAt(0))
    assert(projected.getStruct(0, 1).isNullAt(0))
  }
}
|
||||
|
||||
/**
 * An [[InternalRow]] stub whose single field is always null. Every typed getter
 * throws, so any generated code that calls a getter without first consulting
 * `isNullAt` fails loudly instead of silently reading a value.
 */
object AlwaysNull extends InternalRow {
  // Shape and nullness: one field, always null.
  override def numFields: Int = 1
  override def anyNull: Boolean = true
  override def isNullAt(ordinal: Int): Boolean = true

  // Mutation is a no-op for setNullAt (the field is already null); copying is
  // safe to alias because the object is stateless.
  override def setNullAt(i: Int): Unit = {}
  override def copy(): InternalRow = this
  override def update(i: Int, value: Any): Unit = unsupported

  // Every typed accessor must never be reached: the field is null.
  override def getBoolean(ordinal: Int): Boolean = unsupported
  override def getByte(ordinal: Int): Byte = unsupported
  override def getShort(ordinal: Int): Short = unsupported
  override def getInt(ordinal: Int): Int = unsupported
  override def getLong(ordinal: Int): Long = unsupported
  override def getFloat(ordinal: Int): Float = unsupported
  override def getDouble(ordinal: Int): Double = unsupported
  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = unsupported
  override def getUTF8String(ordinal: Int): UTF8String = unsupported
  override def getBinary(ordinal: Int): Array[Byte] = unsupported
  override def getInterval(ordinal: Int): CalendarInterval = unsupported
  override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported
  override def getArray(ordinal: Int): ArrayData = unsupported
  override def getMap(ordinal: Int): MapData = unsupported
  override def get(ordinal: Int, dataType: DataType): AnyRef = unsupported

  // Nothing conforms to every return type above, so one helper covers them all.
  private def unsupported: Nothing = throw new UnsupportedOperationException
}
|
|
@ -198,21 +198,25 @@ public final class ColumnarBatch {
|
|||
|
||||
@Override
|
||||
public Decimal getDecimal(int ordinal, int precision, int scale) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
return columns[ordinal].getDecimal(rowId, precision, scale);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UTF8String getUTF8String(int ordinal) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
return columns[ordinal].getUTF8String(rowId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] getBinary(int ordinal) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
return columns[ordinal].getBinary(rowId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CalendarInterval getInterval(int ordinal) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
final int months = columns[ordinal].getChildColumn(0).getInt(rowId);
|
||||
final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId);
|
||||
return new CalendarInterval(months, microseconds);
|
||||
|
@ -220,11 +224,13 @@ public final class ColumnarBatch {
|
|||
|
||||
@Override
|
||||
public InternalRow getStruct(int ordinal, int numFields) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
return columns[ordinal].getStruct(rowId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ArrayData getArray(int ordinal) {
|
||||
if (columns[ordinal].isNullAt(rowId)) return null;
|
||||
return columns[ordinal].getArray(rowId);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue