diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0f0f200122..b14c66cc5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -191,7 +191,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean = getAs[Boolean](i) + def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -199,7 +199,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte = getAs[Byte](i) + def getByte(i: Int): Byte = getAnyValAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -207,7 +207,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short = getAs[Short](i) + def getShort(i: Int): Short = getAnyValAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -215,7 +215,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int = getAs[Int](i) + def getInt(i: Int): Int = getAnyValAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -223,7 +223,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long = getAs[Long](i) + def getLong(i: Int): Long = getAnyValAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -232,7 +232,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float = getAs[Float](i) + def getFloat(i: Int): Float = getAnyValAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -240,13 +240,12 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double = getAs[Double](i) + def getDouble(i: Int): Double = getAnyValAs[Double](i) /** * Returns the value at position i as a String object. * * @throws ClassCastException when data type does not match. - * @throws NullPointerException when value is null. */ def getString(i: Int): String = getAs[String](i) @@ -318,6 +317,8 @@ trait Row extends Serializable { /** * Returns the value at position i. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws ClassCastException when data type does not match. */ @@ -325,6 +326,8 @@ trait Row extends Serializable { /** * Returns the value of a given fieldName. + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -344,6 +347,8 @@ trait Row extends Serializable { /** * Returns a Map(name -> value) for the requested fieldNames + * For primitive types if value is null it returns 'zero value' specific for primitive + * ie. 0 for Int - use isNullAt to ensure that value is not null * * @throws UnsupportedOperationException when schema is not defined. * @throws IllegalArgumentException when fieldName do not exist. @@ -458,4 +463,15 @@ trait Row extends Serializable { * start, end, and separator strings. */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + private def getAnyValAs[T <: AnyVal](i: Int): T = + if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null") + else getAs[T](i) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 01ff84cb56..5c22a72192 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers { StructField("col2", StringType) :: StructField("col3", IntegerType) :: Nil) val values = Array("value1", "value2", 1) + val valuesWithoutCol3 = Array[Any](null, "value2", null) val sampleRow: Row = new GenericRowWithSchema(values, schema) + val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema) val noSchemaRow: Row = new GenericRow(values) describe("Row (without schema)") { @@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers { ) sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected } + + it("getValuesMap() retrieves null value on non AnyVal Type") { + val expected = Map( + "col1" -> null, + "col2" -> "value2" + ) + sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected + } + + it("getAs() on type extending AnyVal throws an exception when accessing field that is null") { + intercept[NullPointerException] { + sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3")) + } + } + + it("getAs() on type extending AnyVal does not throw exception when value is null"){ + sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null + } } describe("row equals") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index 252f7cc897..45df2ea655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest { val hashJoinNode = makeUnsafeNode(leftNode, rightNode) val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) val actualOutput = hashJoinNode.collect().map { row => - // (id, name, id, nickname) - (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + // ( + // id, name, + // id, nickname + // ) + ( + Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)), + Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3)) + ) } assert(actualOutput.toSet === expectedOutput.toSet) } @@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest { private def generateExpectedOutput( leftInput: Array[(Int, String)], rightInput: Array[(Int, String)], - joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = { joinType match { case LeftOuter => val rightInputMap = rightInput.toMap leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) - val rightValue = rightInputMap.getOrElse(k, null) - (k, v, rightKey, rightValue) + val rightKey = rightInputMap.get(k).map { _ => k } + val rightValue = rightInputMap.get(k) + (Some(k), Some(v), rightKey, rightValue) } case RightOuter => val leftInputMap = leftInput.toMap rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) } case FullOuter => val leftInputMap = leftInput.toMap val rightInputMap = rightInput.toMap val leftOutput = leftInput.map { case (k, v) => - val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) - val rightValue = rightInputMap.getOrElse(k, null) - (k, v, rightKey, rightValue) + val rightKey = rightInputMap.get(k).map { _ => k } + val rightValue = rightInputMap.get(k) + (Some(k), Some(v), rightKey, rightValue) } val rightOutput = rightInput.map { case (k, v) => - val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) - val leftValue = leftInputMap.getOrElse(k, null) - (leftKey, leftValue, k, v) + val leftKey = leftInputMap.get(k).map { _ => k } + val leftValue = leftInputMap.get(k) + (leftKey, leftValue, Some(k), Some(v)) } (leftOutput ++ rightOutput).distinct