[SPARK-11553][SQL] Primitive Row accessors should not convert null to default value

Invocation of getters for type extending AnyVal returns default value (if field value is null) instead of throwing NPE. Please check comments for SPARK-11553 issue for more details.

Author: Bartlomiej Alberski <bartlomiej.alberski@allegrogroup.com>

Closes #9642 from alberskib/bugfix/SPARK-11553.
This commit is contained in:
Bartlomiej Alberski 2015-11-16 15:14:38 -08:00 committed by Michael Armbrust
parent bcea0bfda6
commit 31296628ac
3 changed files with 65 additions and 23 deletions

View file

@ -191,7 +191,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getBoolean(i: Int): Boolean = getAs[Boolean](i) def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i)
/** /**
* Returns the value at position i as a primitive byte. * Returns the value at position i as a primitive byte.
@ -199,7 +199,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getByte(i: Int): Byte = getAs[Byte](i) def getByte(i: Int): Byte = getAnyValAs[Byte](i)
/** /**
* Returns the value at position i as a primitive short. * Returns the value at position i as a primitive short.
@ -207,7 +207,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getShort(i: Int): Short = getAs[Short](i) def getShort(i: Int): Short = getAnyValAs[Short](i)
/** /**
* Returns the value at position i as a primitive int. * Returns the value at position i as a primitive int.
@ -215,7 +215,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getInt(i: Int): Int = getAs[Int](i) def getInt(i: Int): Int = getAnyValAs[Int](i)
/** /**
* Returns the value at position i as a primitive long. * Returns the value at position i as a primitive long.
@ -223,7 +223,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getLong(i: Int): Long = getAs[Long](i) def getLong(i: Int): Long = getAnyValAs[Long](i)
/** /**
* Returns the value at position i as a primitive float. * Returns the value at position i as a primitive float.
@ -232,7 +232,7 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getFloat(i: Int): Float = getAs[Float](i) def getFloat(i: Int): Float = getAnyValAs[Float](i)
/** /**
* Returns the value at position i as a primitive double. * Returns the value at position i as a primitive double.
@ -240,13 +240,12 @@ trait Row extends Serializable {
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null. * @throws NullPointerException when value is null.
*/ */
def getDouble(i: Int): Double = getAs[Double](i) def getDouble(i: Int): Double = getAnyValAs[Double](i)
/** /**
* Returns the value at position i as a String object. * Returns the value at position i as a String object.
* *
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/ */
def getString(i: Int): String = getAs[String](i) def getString(i: Int): String = getAs[String](i)
@ -318,6 +317,8 @@ trait Row extends Serializable {
/** /**
* Returns the value at position i. * Returns the value at position i.
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
* *
* @throws ClassCastException when data type does not match. * @throws ClassCastException when data type does not match.
*/ */
@ -325,6 +326,8 @@ trait Row extends Serializable {
/** /**
* Returns the value of a given fieldName. * Returns the value of a given fieldName.
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
* *
* @throws UnsupportedOperationException when schema is not defined. * @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist. * @throws IllegalArgumentException when fieldName do not exist.
@ -344,6 +347,8 @@ trait Row extends Serializable {
/** /**
* Returns a Map(name -> value) for the requested fieldNames * Returns a Map(name -> value) for the requested fieldNames
* For primitive types if value is null it returns 'zero value' specific for primitive
* ie. 0 for Int - use isNullAt to ensure that value is not null
* *
* @throws UnsupportedOperationException when schema is not defined. * @throws UnsupportedOperationException when schema is not defined.
* @throws IllegalArgumentException when fieldName do not exist. * @throws IllegalArgumentException when fieldName do not exist.
@ -458,4 +463,15 @@ trait Row extends Serializable {
* start, end, and separator strings. * start, end, and separator strings.
*/ */
def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
/**
* Returns the value of a given fieldName.
*
* @throws UnsupportedOperationException when schema is not defined.
* @throws ClassCastException when data type does not match.
* @throws NullPointerException when value is null.
*/
private def getAnyValAs[T <: AnyVal](i: Int): T =
if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null")
else getAs[T](i)
} }

View file

@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers {
StructField("col2", StringType) :: StructField("col2", StringType) ::
StructField("col3", IntegerType) :: Nil) StructField("col3", IntegerType) :: Nil)
val values = Array("value1", "value2", 1) val values = Array("value1", "value2", 1)
val valuesWithoutCol3 = Array[Any](null, "value2", null)
val sampleRow: Row = new GenericRowWithSchema(values, schema) val sampleRow: Row = new GenericRowWithSchema(values, schema)
val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema)
val noSchemaRow: Row = new GenericRow(values) val noSchemaRow: Row = new GenericRow(values)
describe("Row (without schema)") { describe("Row (without schema)") {
@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers {
) )
sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
} }
it("getValuesMap() retrieves null value on non AnyVal Type") {
val expected = Map(
"col1" -> null,
"col2" -> "value2"
)
sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected
}
it("getAs() on type extending AnyVal throws an exception when accessing field that is null") {
intercept[NullPointerException] {
sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3"))
}
}
it("getAs() on type extending AnyVal does not throw exception when value is null"){
sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null
}
} }
describe("row equals") { describe("row equals") {

View file

@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
val actualOutput = hashJoinNode.collect().map { row => val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname) // (
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) // id, name,
// id, nickname
// )
(
Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)),
Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3))
)
} }
assert(actualOutput.toSet === expectedOutput.toSet) assert(actualOutput.toSet === expectedOutput.toSet)
} }
@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
private def generateExpectedOutput( private def generateExpectedOutput(
leftInput: Array[(Int, String)], leftInput: Array[(Int, String)],
rightInput: Array[(Int, String)], rightInput: Array[(Int, String)],
joinType: JoinType): Array[(Int, String, Int, String)] = { joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = {
joinType match { joinType match {
case LeftOuter => case LeftOuter =>
val rightInputMap = rightInput.toMap val rightInputMap = rightInput.toMap
leftInput.map { case (k, v) => leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) val rightKey = rightInputMap.get(k).map { _ => k }
val rightValue = rightInputMap.getOrElse(k, null) val rightValue = rightInputMap.get(k)
(k, v, rightKey, rightValue) (Some(k), Some(v), rightKey, rightValue)
} }
case RightOuter => case RightOuter =>
val leftInputMap = leftInput.toMap val leftInputMap = leftInput.toMap
rightInput.map { case (k, v) => rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) val leftKey = leftInputMap.get(k).map { _ => k }
val leftValue = leftInputMap.getOrElse(k, null) val leftValue = leftInputMap.get(k)
(leftKey, leftValue, k, v) (leftKey, leftValue, Some(k), Some(v))
} }
case FullOuter => case FullOuter =>
val leftInputMap = leftInput.toMap val leftInputMap = leftInput.toMap
val rightInputMap = rightInput.toMap val rightInputMap = rightInput.toMap
val leftOutput = leftInput.map { case (k, v) => val leftOutput = leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) val rightKey = rightInputMap.get(k).map { _ => k }
val rightValue = rightInputMap.getOrElse(k, null) val rightValue = rightInputMap.get(k)
(k, v, rightKey, rightValue) (Some(k), Some(v), rightKey, rightValue)
} }
val rightOutput = rightInput.map { case (k, v) => val rightOutput = rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) val leftKey = leftInputMap.get(k).map { _ => k }
val leftValue = leftInputMap.getOrElse(k, null) val leftValue = leftInputMap.get(k)
(leftKey, leftValue, k, v) (leftKey, leftValue, Some(k), Some(v))
} }
(leftOutput ++ rightOutput).distinct (leftOutput ++ rightOutput).distinct