[SQL] SPARK-1364 Improve datatype and test coverage for ScalaReflection schema inference.

Author: Michael Armbrust <michael@databricks.com>

Closes #293 from marmbrus/reflectTypes and squashes the following commits:

f54e8e8 [Michael Armbrust] Improve datatype and test coverage for ScalaReflection schema inference.
Michael Armbrust authored on 2014-04-02 18:14:31 -07:00; committed by Patrick Wendell
commit 47ebea5468, parent 9c65fa76f9
2 changed files with 66 additions and 0 deletions


@@ -43,15 +43,25 @@ object ScalaReflection {
      val params = t.member("<init>": TermName).asMethod.paramss
      StructType(
        params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
    // Need to decide if we actually need a special type here.
    case t if t <:< typeOf[Array[Byte]] => BinaryType
    case t if t <:< typeOf[Array[_]] =>
      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
    case t if t <:< typeOf[Seq[_]] =>
      val TypeRef(_, _, Seq(elementType)) = t
      ArrayType(schemaFor(elementType))
    case t if t <:< typeOf[Map[_, _]] =>
      val TypeRef(_, _, Seq(keyType, valueType)) = t
      MapType(schemaFor(keyType), schemaFor(valueType))
    case t if t <:< typeOf[String] => StringType
    case t if t <:< definitions.IntTpe => IntegerType
    case t if t <:< definitions.LongTpe => LongType
    case t if t <:< definitions.FloatTpe => FloatType
    case t if t <:< definitions.DoubleTpe => DoubleType
    case t if t <:< definitions.ShortTpe => ShortType
    case t if t <:< definitions.ByteTpe => ByteType
    case t if t <:< definitions.BooleanTpe => BooleanType
    case t if t <:< typeOf[BigDecimal] => DecimalType
  }

  implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
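
To make the expanded inference concrete, here is a rough sketch of what `schemaFor` now produces for a nested case class. The `Record` class is hypothetical, and the import paths assume the catalyst package layout at the time of this commit:

import scala.reflect.runtime.universe._
import org.apache.spark.sql.catalyst.ScalaReflection._  // assumed package location

case class Record(name: String, scores: Seq[Int], tags: Map[String, BigDecimal])

// Per the cases above, schemaFor(typeOf[Record]) evaluates to:
//   StructType(Seq(
//     StructField("name", StringType, true),
//     StructField("scores", ArrayType(IntegerType), true),
//     StructField("tags", MapType(StringType, DecimalType), true)))
val recordSchema = schemaFor(typeOf[Record])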


@@ -0,0 +1,56 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql

import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext._

case class ReflectData(
    stringField: String,
    intField: Int,
    longField: Long,
    floatField: Float,
    doubleField: Double,
    shortField: Short,
    byteField: Byte,
    booleanField: Boolean,
    decimalField: BigDecimal,
    seqInt: Seq[Int])

case class ReflectBinary(data: Array[Byte])

class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
BigDecimal(1), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerAsTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
}
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
rdd.registerAsTable("reflectBinary")
val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
}
}
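
Why the binary case is asserted separately: Scala `Array` equality is Java reference equality, so comparing rows that contain arrays fails even when the contents match. A minimal illustration (plain Scala, no Spark needed):

val a = Array[Byte](1)
val b = Array[Byte](1)
a == b              // false: arrays compare by reference
a.toSeq == b.toSeq  // true: Seq compares element-wise

Hence the test converts `result` to a `Seq` before comparing it against `Seq[Byte](1)`.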