[SQL] SPARK-1732 - Support for null primitive values.
I also removed a println that I bumped into.

Author: Michael Armbrust <michael@databricks.com>

Closes #658 from marmbrus/nullPrimitives and squashes the following commits:

a3ec4f3 [Michael Armbrust] Remove println.
695606b [Michael Armbrust] Support for null primitives when using Scala and Java reflection.
parent a2262cdb7a
commit 3c64750bdd
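For orientation, here is a minimal sketch of what the change enables, modeled on the new test cases below; the case class, table name, and use of the 1.0-era TestSQLContext helpers are illustrative assumptions, not part of this commit:

    import org.apache.spark.sql.test.TestSQLContext._

    // Boxed Java wrappers and Options in a case class can now hold null/None;
    // before this commit, schemaFor had no case for either, so schema
    // inference failed rather than producing a nullable column.
    case class Record(id: java.lang.Integer, score: Option[Double])

    val rdd = sparkContext.parallelize(Record(null, None) :: Nil)
    rdd.registerAsTable("records")

    // Both columns come back as SQL NULLs.
    sql("SELECT id, score FROM records").collect()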
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

@@ -44,7 +44,8 @@ object ScalaReflection {
     case t if t <:< typeOf[Product] =>
       val params = t.member("<init>": TermName).asMethod.paramss
       StructType(
-        params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+        params.head.map(p =>
+          StructField(p.name.toString, schemaFor(p.typeSignature), nullable = true)))
     // Need to decide if we actually need a special type here.
     case t if t <:< typeOf[Array[Byte]] => BinaryType
     case t if t <:< typeOf[Array[_]] =>
@@ -58,6 +59,17 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => StringType
     case t if t <:< typeOf[Timestamp] => TimestampType
     case t if t <:< typeOf[BigDecimal] => DecimalType
+    case t if t <:< typeOf[Option[_]] =>
+      val TypeRef(_, _, Seq(optType)) = t
+      schemaFor(optType)
+    case t if t <:< typeOf[java.lang.Integer] => IntegerType
+    case t if t <:< typeOf[java.lang.Long] => LongType
+    case t if t <:< typeOf[java.lang.Double] => DoubleType
+    case t if t <:< typeOf[java.lang.Float] => FloatType
+    case t if t <:< typeOf[java.lang.Short] => ShortType
+    case t if t <:< typeOf[java.lang.Byte] => ByteType
+    case t if t <:< typeOf[java.lang.Boolean] => BooleanType
+    // TODO: The following datatypes could be marked as non-nullable.
     case t if t <:< definitions.IntTpe => IntegerType
     case t if t <:< definitions.LongTpe => LongType
     case t if t <:< definitions.DoubleTpe => DoubleType
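As an aside on the mechanism: schemaFor dispatches with scala-reflect `<:<` subtype checks, and the new Option case recurses into the element type while nullability itself is recorded on the enclosing StructField. A standalone sketch of the same dispatch pattern (illustrative, not Spark code):

    import scala.reflect.runtime.universe._

    // Unwrap Option to its element type, then map boxed wrappers and
    // primitives to the same logical type, as schemaFor does above.
    def describe(tpe: Type): String = tpe match {
      case t if t <:< typeOf[Option[_]] =>
        val TypeRef(_, _, Seq(optType)) = t
        s"nullable ${describe(optType)}"
      case t if t <:< typeOf[java.lang.Integer] => "IntegerType"
      case t if t <:< definitions.IntTpe => "IntegerType"
      case t => t.toString
    }

    describe(typeOf[Option[Int]]) // "nullable IntegerType"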
sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala

@@ -132,6 +132,14 @@ class JavaSQLContext(sparkContext: JavaSparkContext) {
         case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
         case c: Class[_] if c == java.lang.Float.TYPE => FloatType
         case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+        case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+        case c: Class[_] if c == classOf[java.lang.Long] => LongType
+        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+        case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+        case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
       }
       // TODO: Nullability could be stricter.
       AttributeReference(property.getName, dataType, nullable = true)()
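The Java-bean path compares raw Class objects, and a boxed wrapper is a different Class from its primitive TYPE token, which is why the wrapper cases must be added separately. The distinction in isolation (illustrative sketch):

    // java.lang.Integer.TYPE is the Class token for the primitive `int`;
    // classOf[java.lang.Integer] is the nullable wrapper class.
    def describe(c: Class[_]): String = c match {
      case k if k == java.lang.Integer.TYPE => "int (primitive, never null)"
      case k if k == classOf[java.lang.Integer] => "Integer (boxed, may be null)"
      case other => other.getName
    }

    describe(java.lang.Integer.TYPE)     // "int (primitive, never null)"
    describe(classOf[java.lang.Integer]) // "Integer (boxed, may be null)"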
sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
 /**
  * A result row from a SparkSQL query.
  */
-class Row(row: ScalaRow) extends Serializable {
+class Row(private[spark] val row: ScalaRow) extends Serializable {
 
   /** Returns the number of columns present in this Row. */
   def length: Int = row.length
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

@@ -164,6 +164,7 @@ case class Sort(
 @DeveloperApi
 object ExistingRdd {
   def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
     case s: Seq[Any] => s.map(convertToCatalyst)
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
     case other => other
@@ -180,7 +181,7 @@ object ExistingRdd {
       bufferedIterator.map { r =>
         var i = 0
         while (i < mutableRow.length) {
-          mutableRow(i) = r.productElement(i)
+          mutableRow(i) = convertToCatalyst(r.productElement(i))
           i += 1
         }
         mutableRow
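With the new Option case, convertToCatalyst flattens None to a SQL NULL (and Some to its payload) as rows are built, instead of leaking scala.Option objects into Catalyst rows. Its behaviour in isolation (a sketch mirroring the cases above, minus the Product case):

    def convert(a: Any): Any = a match {
      case o: Option[_] => o.orNull       // None -> null, Some(v) -> v
      case s: Seq[Any]  => s.map(convert) // convert elements recursively
      case other        => other
    }

    convert(Some(1))            // 1
    convert(None)               // null
    convert(Seq(Some(1), None)) // List(1, null)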
sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala

@@ -36,6 +36,24 @@ case class ReflectData(
     timestampField: Timestamp,
     seqInt: Seq[Int])
 
+case class NullReflectData(
+    intField: java.lang.Integer,
+    longField: java.lang.Long,
+    floatField: java.lang.Float,
+    doubleField: java.lang.Double,
+    shortField: java.lang.Short,
+    byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean)
+
+case class OptionalReflectData(
+    intField: Option[Int],
+    longField: Option[Long],
+    floatField: Option[Float],
+    doubleField: Option[Double],
+    shortField: Option[Short],
+    byteField: Option[Byte],
+    booleanField: Option[Boolean])
+
 case class ReflectBinary(data: Array[Byte])
 
 class ScalaReflectionRelationSuite extends FunSuite {
@@ -48,6 +66,22 @@ class ScalaReflectionRelationSuite extends FunSuite {
     assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
   }
 
+  test("query case class RDD with nulls") {
+    val data = NullReflectData(null, null, null, null, null, null, null)
+    val rdd = sparkContext.parallelize(data :: Nil)
+    rdd.registerAsTable("reflectNullData")
+
+    assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null))
+  }
+
+  test("query case class RDD with Nones") {
+    val data = OptionalReflectData(None, None, None, None, None, None, None)
+    val rdd = sparkContext.parallelize(data :: Nil)
+    rdd.registerAsTable("reflectOptionalData")
+
+    assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null))
+  }
+
   // Equality is broken for Arrays, so we test that separately.
   test("query binary data") {
     val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
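The "Equality is broken for Arrays" comment refers to Java reference semantics: Scala's == on arrays does not compare contents, so the === sequence assertions above would spuriously fail for the binary column. For example:

    Array[Byte](1) == Array[Byte](1)             // false: reference equality
    Array[Byte](1).toSeq == Array[Byte](1).toSeq // true: structural equality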
sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala

@@ -35,6 +35,17 @@ class PersonBean extends Serializable {
   var age: Int = _
 }
 
+class AllTypesBean extends Serializable {
+  @BeanProperty var stringField: String = _
+  @BeanProperty var intField: java.lang.Integer = _
+  @BeanProperty var longField: java.lang.Long = _
+  @BeanProperty var floatField: java.lang.Float = _
+  @BeanProperty var doubleField: java.lang.Double = _
+  @BeanProperty var shortField: java.lang.Short = _
+  @BeanProperty var byteField: java.lang.Byte = _
+  @BeanProperty var booleanField: java.lang.Boolean = _
+}
+
 class JavaSQLSuite extends FunSuite {
   val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
   val javaSqlCtx = new JavaSQLContext(javaCtx)
@@ -50,4 +61,54 @@ class JavaSQLSuite extends FunSuite {
     schemaRDD.registerAsTable("people")
     javaSqlCtx.sql("SELECT * FROM people").collect()
   }
+
+  test("all types in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField("")
+    bean.setIntField(0)
+    bean.setLongField(0)
+    bean.setFloatField(0.0F)
+    bean.setDoubleField(0.0)
+    bean.setShortField(0.toShort)
+    bean.setByteField(0.toByte)
+    bean.setBooleanField(false)
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerAsTable("allTypes")
+
+    assert(
+      javaSqlCtx.sql(
+        """
+          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+          |       booleanField
+          |FROM allTypes
+        """.stripMargin).collect.head.row ===
+      Seq("", 0, 0L, 0F, 0.0, 0.toShort, 0.toByte, false))
+  }
+
+  test("all types null in JavaBeans") {
+    val bean = new AllTypesBean
+    bean.setStringField(null)
+    bean.setIntField(null)
+    bean.setLongField(null)
+    bean.setFloatField(null)
+    bean.setDoubleField(null)
+    bean.setShortField(null)
+    bean.setByteField(null)
+    bean.setBooleanField(null)
+
+    val rdd = javaCtx.parallelize(bean :: Nil)
+    val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean])
+    schemaRDD.registerAsTable("allTypes")
+
+    assert(
+      javaSqlCtx.sql(
+        """
+          |SELECT stringField, intField, longField, floatField, doubleField, shortField, byteField,
+          |       booleanField
+          |FROM allTypes
+        """.stripMargin).collect.head.row ===
+      Seq.fill(8)(null))
+  }
+}
sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala

@@ -21,11 +21,12 @@ import java.nio.ByteBuffer
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.Logging
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.SparkSqlSerializer
 
-class ColumnTypeSuite extends FunSuite {
+class ColumnTypeSuite extends FunSuite with Logging {
   val DEFAULT_BUFFER_SIZE = 512
 
   test("defaultSize") {
@@ -163,7 +164,7 @@ class ColumnTypeSuite extends FunSuite {
 
     buffer.rewind()
     seq.foreach { expected =>
-      println("buffer = " + buffer + ", expected = " + expected)
+      logger.info("buffer = " + buffer + ", expected = " + expected)
      val extracted = columnType.extract(buffer)
       assert(
         expected === extracted,
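The removed println ran once per extracted value inside the column-type round-trip loop, flooding test output; routing it through the sql Logging trait makes it level-controlled instead. The same pattern in isolation (a sketch, assuming the 1.0-era org.apache.spark.sql.Logging trait that exposes a `logger`, as the diff above does):

    import org.apache.spark.sql.Logging
    import org.scalatest.FunSuite

    class QuietSuite extends FunSuite with Logging {
      test("logs instead of printing") {
        // Emitted only when the configured log level permits INFO.
        logger.info("buffer state goes here")
      }
    }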