From ce8edb8bf4db5f82bcfeb11efbdf5229b0d25dfa Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Fri, 19 May 2017 13:18:48 +0200 Subject: [PATCH] [SPARK-20798] GenerateUnsafeProjection should check if a value is null before calling the getter ## What changes were proposed in this pull request? GenerateUnsafeProjection.writeStructToBuffer() did not honor the assumption that the caller must make sure that a value is not null before using the getter. This could lead to various errors. This change fixes that behavior. Example of code generated before: ```scala /* 059 */ final UTF8String fieldName = value.getUTF8String(0); /* 060 */ if (value.isNullAt(0)) { /* 061 */ rowWriter1.setNullAt(0); /* 062 */ } else { /* 063 */ rowWriter1.write(0, fieldName); /* 064 */ } ``` Example of code generated now: ```scala /* 060 */ boolean isNull1 = value.isNullAt(0); /* 061 */ UTF8String value1 = isNull1 ? null : value.getUTF8String(0); /* 062 */ if (isNull1) { /* 063 */ rowWriter1.setNullAt(0); /* 064 */ } else { /* 065 */ rowWriter1.write(0, value1); /* 066 */ } ``` ## How was this patch tested? Adds GenerateUnsafeProjectionSuite. Author: Ala Luszczak Closes #18030 from ala/fix-generate-unsafe-projection. --- .../codegen/GenerateUnsafeProjection.scala | 15 +++-- .../GenerateUnsafeProjectionSuite.scala | 61 +++++++++++++++++++ .../execution/vectorized/ColumnarBatch.java | 6 ++ 3 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2..b358102d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -50,10 +50,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro fieldTypes: Seq[DataType], bufferHolder: String): String = { val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val fieldName = ctx.freshName("fieldName") - val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" - val isNull = s"$input.isNullAt($i)" - ExprCode(code, isNull, fieldName) + val javaType = ctx.javaType(dt) + val isNullVar = ctx.freshName("isNull") + val valueVar = ctx.freshName("value") + val defaultValue = ctx.defaultValue(dt) + val readValue = ctx.getValue(input, dt, i.toString) + val code = + s""" + boolean $isNullVar = $input.isNullAt($i); + $javaType $valueVar = $isNullVar ? $defaultValue : $readValue; + """ + ExprCode(code, isNullVar, valueVar) } s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala new file mode 100644 index 0000000000..e9d21f8a8e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{DataType, Decimal, StringType, StructType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class GenerateUnsafeProjectionSuite extends SparkFunSuite { + test("Test unsafe projection string access pattern") { + val dataType = (new StructType).add("a", StringType) + val exprs = BoundReference(0, dataType, nullable = true) :: Nil + val projection = GenerateUnsafeProjection.generate(exprs) + val result = projection.apply(InternalRow(AlwaysNull)) + assert(!result.isNullAt(0)) + assert(result.getStruct(0, 1).isNullAt(0)) + } +} + +object AlwaysNull extends InternalRow { + override def numFields: Int = 1 + override def setNullAt(i: Int): Unit = {} + override def copy(): InternalRow = this + override def anyNull: Boolean = true + override def isNullAt(ordinal: Int): Boolean = true + override def update(i: Int, value: Any): Unit = notSupported + override def getBoolean(ordinal: Int): Boolean = notSupported + override def getByte(ordinal: Int): Byte = notSupported + override def getShort(ordinal: Int): Short = notSupported + override def getInt(ordinal: Int): Int = notSupported + override def getLong(ordinal: Int): Long = notSupported + override def getFloat(ordinal: Int): Float = notSupported + override def getDouble(ordinal: Int): Double = notSupported + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = notSupported + override def getUTF8String(ordinal: Int): UTF8String = notSupported + override def getBinary(ordinal: Int): Array[Byte] = notSupported + override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported + override def getArray(ordinal: Int): ArrayData = notSupported + override def getMap(ordinal: Int): MapData = notSupported + override def get(ordinal: Int, dataType: DataType): AnyRef = notSupported + private def notSupported: Nothing = throw new UnsupportedOperationException +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index a6ce4c2edc..8b7b0e655b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -198,21 +198,25 @@ public final class ColumnarBatch { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getUTF8String(rowId); } @Override public byte[] getBinary(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getBinary(rowId); } @Override public CalendarInterval getInterval(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; final int months = columns[ordinal].getChildColumn(0).getInt(rowId); final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); return new CalendarInterval(months, microseconds); @@ -220,11 +224,13 @@ public final class ColumnarBatch { @Override public InternalRow getStruct(int ordinal, int numFields) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getStruct(rowId); } @Override public ArrayData getArray(int ordinal) { + if (columns[ordinal].isNullAt(rowId)) return null; return columns[ordinal].getArray(rowId); }