[SPARK-4573] [SQL] Add SettableStructObjectInspector support in "wrap" function

A Hive UDAF may create a customized object whose layout is described by a SettableStructObjectInspector; supporting this is critical when integrating Hive UDAFs with the refactored UDAF interface.

The additional match cases introduce a performance concern in `wrap`/`unwrap`; that will be addressed in a follow-up PR.
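The follow-up could take roughly this shape, as a hypothetical sketch only (the name `unwrapperFor` is made up, and it assumes a caller that mixes in `HiveInspectors`): build the conversion closure once per ObjectInspector instead of re-matching for every value.

{{{
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
import org.apache.hadoop.hive.serde2.objectinspector.primitive.{IntObjectInspector, StringObjectInspector}

// Hypothetical sketch, not part of this patch: resolve the ObjectInspector once,
// return a cheap (Any => Any) conversion to apply per value.
def unwrapperFor(oi: ObjectInspector): Any => Any = oi match {
  case x: StringObjectInspector if x.preferWritable() =>
    data => if (data == null) null else x.getPrimitiveWritableObject(data).toString
  case x: IntObjectInspector =>
    data => if (data == null) null else x.get(data)
  // ... the remaining cases would mirror the match arms in unwrap ...
  case other =>
    data => unwrap(data, other) // fall back to the per-value match
}
}}}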

Author: Cheng Hao <hao.cheng@intel.com>

Closes #3429 from chenghao-intel/settable_oi and squashes the following commits:

9f0aff3 [Cheng Hao] update code style issues as feedbacks
2b0561d [Cheng Hao] Add more scala doc
f5a40e8 [Cheng Hao] add scala doc
2977e9b [Cheng Hao] remove the timezone setting for test suite
3ed284c [Cheng Hao] fix the date type comparison
f1b6749 [Cheng Hao] Update the comment
932940d [Cheng Hao] Add more unit test
72e4332 [Cheng Hao] Add settable StructObjectInspector support
Authored by Cheng Hao on 2014-12-18 20:21:52 -08:00, committed by Michael Armbrust
parent 7687415c25
commit ae9f128608
4 changed files with 661 additions and 126 deletions

@@ -18,9 +18,7 @@
package org.apache.spark.sql.hive
import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
@@ -33,6 +31,145 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
/* Implicit conversions */
import scala.collection.JavaConversions._
/**
* 1. The Underlying data type in catalyst and in Hive
* In catalyst:
* Primitive =>
* java.lang.String
* int / scala.Int
* boolean / scala.Boolean
* float / scala.Float
* double / scala.Double
* long / scala.Long
* short / scala.Short
* byte / scala.Byte
* org.apache.spark.sql.catalyst.types.decimal.Decimal
* Array[Byte]
* java.sql.Date
* java.sql.Timestamp
* Complex Types =>
* Map: scala.collection.immutable.Map
* List: scala.collection.immutable.Seq
* Struct:
* org.apache.spark.sql.catalyst.expression.Row
* Union: NOT SUPPORTED YET
* The complex types act as containers, which can hold arbitrary data types.
*
* In Hive, the native data types used in UDF/UDAF/UDTF evaluation vary and are associated with
* Object Inspectors; in the Hive expression evaluation framework, the underlying data are:
* Primitive Type
* Java Boxed Primitives:
* org.apache.hadoop.hive.common.type.HiveVarchar
* java.lang.String
* java.lang.Integer
* java.lang.Boolean
* java.lang.Float
* java.lang.Double
* java.lang.Long
* java.lang.Short
* java.lang.Byte
* org.apache.hadoop.hive.common.`type`.HiveDecimal
* byte[]
* java.sql.Date
* java.sql.Timestamp
* Writables:
* org.apache.hadoop.hive.serde2.io.HiveVarcharWritable
* org.apache.hadoop.io.Text
* org.apache.hadoop.io.IntWritable
* org.apache.hadoop.hive.serde2.io.DoubleWritable
* org.apache.hadoop.io.BooleanWritable
* org.apache.hadoop.io.LongWritable
* org.apache.hadoop.io.FloatWritable
* org.apache.hadoop.hive.serde2.io.ShortWritable
* org.apache.hadoop.hive.serde2.io.ByteWritable
* org.apache.hadoop.io.BytesWritable
* org.apache.hadoop.hive.serde2.io.DateWritable
* org.apache.hadoop.hive.serde2.io.TimestampWritable
* org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
* Complex Type
* List: Object[] / java.util.List
* Map: java.util.Map
* Struct: Object[] / java.util.List / java POJO
* Union: class StandardUnion { byte tag; Object object }
*
* NOTICE: HiveVarchar is not supported by catalyst; it is simply treated as a String.
*
*
* 2. Hive ObjectInspector is a group of flexible APIs to inspect values in different data
* representations, and developers can extend those APIs as needed, so technically an
* object inspector can support an arbitrary data type in Java.
*
* Fortunately, only a few built-in Hive Object Inspectors are used in generic udf/udaf/udtf
* evaluation.
* 1) Primitive Types (PrimitiveObjectInspector & its sub classes)
{{{
public interface PrimitiveObjectInspector {
// Writables (hadoop.io.IntWritable, hadoop.io.Text etc.)
Object getPrimitiveWritableObject(Object o);
// Java Primitives (java.lang.Integer, java.lang.String etc.)
Object getPrimitiveJavaObject(Object o);
// Whether this ObjectInspector prefers to inspect the `writable` form of the data;
// we need to check it before deciding which of the methods above to invoke.
boolean preferWritable();
...
}
}}}
* 2) Complex Types:
* ListObjectInspector: inspects java array or [[java.util.List]]
* MapObjectInspector: inspects [[java.util.Map]]
* StructObjectInspector: inspects java array, [[java.util.List]] and
* even a normal java object (POJO)
* UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet)
*
* 3) ConstantObjectInspector:
* A constant object inspector can be of either primitive or complex type, and it bundles a
* constant value as its property; usually the value is created when the constant object
* inspector is constructed.
* {{{
public interface ConstantObjectInspector extends ObjectInspector {
Object getWritableConstantValue();
...
}
}}}
* Hive provides three categories of built-in constant object inspectors:
* Primitive Object Inspectors:
* WritableConstantStringObjectInspector
* WritableConstantHiveVarcharObjectInspector
* WritableConstantHiveDecimalObjectInspector
* WritableConstantTimestampObjectInspector
* WritableConstantIntObjectInspector
* WritableConstantDoubleObjectInspector
* WritableConstantBooleanObjectInspector
* WritableConstantLongObjectInspector
* WritableConstantFloatObjectInspector
* WritableConstantShortObjectInspector
* WritableConstantByteObjectInspector
* WritableConstantBinaryObjectInspector
* WritableConstantDateObjectInspector
* Map Object Inspector:
* StandardConstantMapObjectInspector
* List Object Inspector:
* StandardConstantListObjectInspector
* Struct Object Inspector: Hive doesn't provide a built-in constant object inspector for Struct
* Union Object Inspector: Hive doesn't provide a built-in constant object inspector for Union
*
*
* 3. This trait facilitates:
* Data Unwrapping: Hive Data => Catalyst Data (unwrap)
* Data Wrapping: Catalyst Data => Hive Data (wrap)
* Binding the Object Inspector for Catalyst Data (toInspector)
* Retrieving the Catalyst Data Type from Object Inspector (inspectorToDataType)
*
*
* 4. Future Improvement (TODO)
* This implementation is quite ugly and inefficient:
* a. Pattern matching at runtime
* b. Small object creation when converting catalyst data => writable
* c. Unnecessary unwrap / wrap for nested UDF invocation:
* e.g. date_add(printf("%s-%s-%s", a,b,c), 3)
* We don't need to unwrap the output of printf and wrap it again before passing it to date_add
*/
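To make item 3 above concrete, a minimal round-trip sketch (assuming the caller mixes in HiveInspectors, as the new HiveInspectorSuite below does; the schema and values are made up):

{{{
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}
import org.apache.spark.sql.catalyst.types._

// Catalyst schema -> Hive (java) object inspector.
val schema = StructType(
  StructField("id", IntegerType) ::
  StructField("name", StringType) :: Nil)
val oi = toInspector(schema)

// Catalyst Row -> Hive struct (a java.util.List here) -> Catalyst Row again.
val row: Row = new GenericRow(Array[Any](1, "a"))
val hiveStruct = wrap(row, oi)
val back = unwrap(hiveStruct, oi).asInstanceOf[Row]
}}}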
private[hive] trait HiveInspectors {
def javaClassToDataType(clz: Class[_]): DataType = clz match {
@@ -87,10 +224,23 @@ private[hive] trait HiveInspectors {
* @param oi the ObjectInspector associated with the Hive Type
* @return convert the data into catalyst type
* TODO return the function of (data => Any) instead for performance consideration
*
* Strictly follows the following order in unwrapping (constant OI has the higher priority):
* Constant Null object inspector =>
* return null
* Constant object inspector =>
* extract the value from constant object inspector
* Check whether the `data` is null =>
* return null if true
* If object inspector prefers writable =>
* extract writable from `data` and then get the catalyst type from the writable
* Extract the java object directly from the object inspector
*
* NOTICE: the complex data type requires recursive unwrapping.
*/
def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
-    case _ if data == null => null
-    case poi: VoidObjectInspector => null
+    case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null
+    case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString
case poi: WritableConstantHiveVarcharObjectInspector =>
poi.getWritableConstantValue.getHiveVarchar.getValue
case poi: WritableConstantHiveDecimalObjectInspector =>
@@ -119,12 +269,44 @@
System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
temp
case poi: WritableConstantDateObjectInspector => poi.getWritableConstantValue.get()
-    case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
-    case hdoi: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(hdoi, data)
-    // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
-    // if next timestamp is null, so Timestamp object is cloned
-    case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
-    case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+    case mi: StandardConstantMapObjectInspector =>
+      // take the value from the map inspector object, rather than the input data
+      mi.getWritableConstantValue.map { case (k, v) =>
+        (unwrap(k, mi.getMapKeyObjectInspector),
+          unwrap(v, mi.getMapValueObjectInspector))
+      }.toMap
case li: StandardConstantListObjectInspector =>
// take the value from the list inspector object, rather than the input data
li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq
// if the value is null, we don't care about the object inspector type
case _ if data == null => null
case poi: VoidObjectInspector => null // always be null for void object inspector
case pi: PrimitiveObjectInspector => pi match {
// We think HiveVarchar is also a String
case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue
case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue
case x: StringObjectInspector if x.preferWritable() =>
x.getPrimitiveWritableObject(data).toString
case x: IntObjectInspector if x.preferWritable() => x.get(data)
case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
case x: FloatObjectInspector if x.preferWritable() => x.get(data)
case x: DoubleObjectInspector if x.preferWritable() => x.get(data)
case x: LongObjectInspector if x.preferWritable() => x.get(data)
case x: ShortObjectInspector if x.preferWritable() => x.get(data)
case x: ByteObjectInspector if x.preferWritable() => x.get(data)
case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data)
case x: BinaryObjectInspector if x.preferWritable() =>
x.getPrimitiveWritableObject(data).copyBytes()
case x: DateObjectInspector if x.preferWritable() =>
x.getPrimitiveWritableObject(data).get()
// org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
// if next timestamp is null, so Timestamp object is cloned
case x: TimestampObjectInspector if x.preferWritable() =>
x.getPrimitiveWritableObject(data).getTimestamp.clone()
case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
case _ => pi.getPrimitiveJavaObject(data)
}
case li: ListObjectInspector =>
Option(li.getList(data))
.map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq)
@@ -132,10 +314,11 @@
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
_.map {
-    case (k,v) =>
+    case (k, v) =>
(unwrap(k, mi.getMapKeyObjectInspector),
unwrap(v, mi.getMapValueObjectInspector))
}.toMap).orNull
// currently, hive doesn't provide the ConstantStructObjectInspector
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
new GenericRow(
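For illustration, a minimal sketch of the unwrapping order described above (it assumes the caller mixes in HiveInspectors and uses the HiveShim helper added by this patch):

{{{
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.IntWritable

// Writable object inspectors prefer the writable form, so unwrap now goes through
// the preferWritable() branches instead of getPrimitiveJavaObject.
val intOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector
assert(intOI.preferWritable())
assert(unwrap(new IntWritable(42), intOI) == 42)

// Constant object inspectors are matched before the null check, so the bundled
// constant is returned even when the input data is null.
val constOI = HiveShim.getIntWritableConstantObjectInspector(7)
assert(unwrap(null, constOI) == 7)
}}}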
@@ -191,55 +374,89 @@
* the ObjectInspector should also be consistent with those returned from
* toInspector: DataType => ObjectInspector and
* toInspector: Expression => ObjectInspector
*
* Strictly follows the following order in wrapping (constant OI has the higher priority):
* Constant object inspector => return the bundled value of Constant object inspector
* Check whether the `a` is null => return null if true
* If object inspector prefers writable object => return a Writable for the given data `a`
* Map the catalyst data to the boxed java primitive
*
* NOTICE: the complex data type requires recursive wrapping.
*/
-  def wrap(a: Any, oi: ObjectInspector): AnyRef = if (a == null) {
-    null
-  } else {
-    oi match {
-      case x: ConstantObjectInspector => x.getWritableConstantValue
-      case x: PrimitiveObjectInspector => a match {
-        // TODO what if x.preferWritable() == true? reuse the writable?
-        case s: String => s: java.lang.String
-        case i: Int => i: java.lang.Integer
-        case b: Boolean => b: java.lang.Boolean
-        case f: Float => f: java.lang.Float
-        case d: Double => d: java.lang.Double
-        case l: Long => l: java.lang.Long
-        case l: Short => l: java.lang.Short
-        case l: Byte => l: java.lang.Byte
-        case b: BigDecimal => HiveShim.createDecimal(b.underlying())
-        case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying())
-        case b: Array[Byte] => b
-        case d: java.sql.Date => d
-        case t: java.sql.Timestamp => t
-      }
-      case x: StructObjectInspector =>
-        val fieldRefs = x.getAllStructFieldRefs
-        val row = a.asInstanceOf[Seq[_]]
-        val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
-        var i = 0
-        while (i < fieldRefs.length) {
-          result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
-          i += 1
-        }
-        result
-      case x: ListObjectInspector =>
-        val list = new java.util.ArrayList[Object]
-        a.asInstanceOf[Seq[_]].foreach {
-          v => list.add(wrap(v, x.getListElementObjectInspector))
-        }
-        list
-      case x: MapObjectInspector =>
-        // Some UDFs seem to assume we pass in a HashMap.
-        val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
-        hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
-          case (k, v) =>
-            wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
-        })
-        hashMap
-    }
-  }
+  def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match {
+    case x: ConstantObjectInspector => x.getWritableConstantValue
+    case _ if a == null => null
+    case x: PrimitiveObjectInspector => x match {
+      // TODO we don't support the HiveVarcharObjectInspector yet.
+      case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a)
+      case _: StringObjectInspector => a.asInstanceOf[java.lang.String]
+      case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a)
+      case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer]
+      case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a)
+      case _: BooleanObjectInspector => a.asInstanceOf[java.lang.Boolean]
+      case _: FloatObjectInspector if x.preferWritable() => HiveShim.getFloatWritable(a)
+      case _: FloatObjectInspector => a.asInstanceOf[java.lang.Float]
+      case _: DoubleObjectInspector if x.preferWritable() => HiveShim.getDoubleWritable(a)
+      case _: DoubleObjectInspector => a.asInstanceOf[java.lang.Double]
+      case _: LongObjectInspector if x.preferWritable() => HiveShim.getLongWritable(a)
+      case _: LongObjectInspector => a.asInstanceOf[java.lang.Long]
+      case _: ShortObjectInspector if x.preferWritable() => HiveShim.getShortWritable(a)
+      case _: ShortObjectInspector => a.asInstanceOf[java.lang.Short]
+      case _: ByteObjectInspector if x.preferWritable() => HiveShim.getByteWritable(a)
+      case _: ByteObjectInspector => a.asInstanceOf[java.lang.Byte]
+      case _: HiveDecimalObjectInspector if x.preferWritable() =>
+        HiveShim.getDecimalWritable(a.asInstanceOf[Decimal])
+      case _: HiveDecimalObjectInspector =>
+        HiveShim.createDecimal(a.asInstanceOf[Decimal].toBigDecimal.underlying())
+      case _: BinaryObjectInspector if x.preferWritable() => HiveShim.getBinaryWritable(a)
+      case _: BinaryObjectInspector => a.asInstanceOf[Array[Byte]]
+      case _: DateObjectInspector if x.preferWritable() => HiveShim.getDateWritable(a)
+      case _: DateObjectInspector => a.asInstanceOf[java.sql.Date]
+      case _: TimestampObjectInspector if x.preferWritable() => HiveShim.getTimestampWritable(a)
+      case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp]
+    }
case x: SettableStructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
val row = a.asInstanceOf[Seq[_]]
// 1. create the pojo (most likely) object
val result = x.create()
var i = 0
while (i < fieldRefs.length) {
// 2. set the property for the pojo
x.setStructFieldData(
result,
fieldRefs.get(i),
wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
i += 1
}
result
case x: StructObjectInspector =>
val fieldRefs = x.getAllStructFieldRefs
val row = a.asInstanceOf[Seq[_]]
val result = new java.util.ArrayList[AnyRef](fieldRefs.length)
var i = 0
while (i < fieldRefs.length) {
result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector))
i += 1
}
result
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
a.asInstanceOf[Seq[_]].foreach {
v => list.add(wrap(v, x.getListElementObjectInspector))
}
list
case x: MapObjectInspector =>
// Some UDFs seem to assume we pass in a HashMap.
val hashMap = new java.util.HashMap[AnyRef, AnyRef]()
hashMap.putAll(a.asInstanceOf[Map[_, _]].map {
case (k, v) =>
wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector)
})
hashMap
-  }
+  }
def wrap(
@@ -254,6 +471,11 @@ private[hive] trait HiveInspectors {
cache
}
/**
* @param dataType Catalyst data type
* @return Hive java object inspector (recursively), not the Writable ObjectInspector
* We can easily map to the Hive built-in object inspector according to the data type.
*/
def toInspector(dataType: DataType): ObjectInspector = dataType match {
case ArrayType(tpe, _) =>
ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
@@ -272,12 +494,20 @@
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
// TODO decimal precision?
case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
}
/**
* Map the catalyst expression to an ObjectInspector. If the expression is a [[Literal]]
* or is foldable, a constant writable object inspector is returned;
* otherwise, the object inspector is chosen according to the expression's (catalyst) data type.
* @param expr Catalyst expression to be mapped
* @return Hive java object inspector (recursively).
*/
def toInspector(expr: Expression): ObjectInspector = expr match {
case Literal(value, StringType) =>
HiveShim.getStringWritableConstantObjectInspector(value)
@@ -326,8 +556,12 @@
})
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map)
}
// We enumerate all of the possible constant expressions; throw an exception if we missed any
case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].")
// Ideally we wouldn't test foldability here (but in the optimizer); however, some Hive
// UDFs / UDAFs require their arguments to be constant object inspectors, so we do it eagerly.
case _ if expr.foldable => toInspector(Literal(expr.eval(), expr.dataType))
// For non-constant expressions, map to an object inspector according to the data type
case _ => toInspector(expr.dataType)
}
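For illustration, a minimal sketch of the two paths (it assumes the caller mixes in HiveInspectors; the expressions are made up):

{{{
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.types.IntegerType

// A foldable expression (1 + 2) is evaluated eagerly and mapped to a constant writable OI.
val constOI = toInspector(Add(Literal(1), Literal(2)))
assert(constOI.isInstanceOf[ConstantObjectInspector])

// A non-constant expression falls back to the data-type based (java) object inspector.
val attrOI = toInspector(AttributeReference("a", IntegerType, nullable = true)())
assert(!attrOI.isInstanceOf[ConstantObjectInspector])
}}}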

@@ -0,0 +1,220 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive
import java.sql.Date
import java.util
import org.apache.hadoop.hive.serde2.io.DoubleWritable
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.apache.hadoop.hive.ql.udf.UDAFPercentile
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StructObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.hadoop.io.LongWritable
import org.apache.spark.sql.catalyst.expressions.{Literal, Row}
class HiveInspectorSuite extends FunSuite with HiveInspectors {
test("Test wrap SettableStructObjectInspector") {
val udaf = new UDAFPercentile.PercentileLongEvaluator()
udaf.init()
udaf.iterate(new LongWritable(1), 0.1)
udaf.iterate(new LongWritable(1), 0.1)
val state = udaf.terminatePartial()
val soi = ObjectInspectorFactory.getReflectionObjectInspector(
classOf[UDAFPercentile.State],
ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector]
val a = unwrap(state, soi).asInstanceOf[Row]
val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State]
val sfCounts = soi.getStructFieldRef("counts")
val sfPercentiles = soi.getStructFieldRef("percentiles")
assert(2 === soi.getStructFieldData(b, sfCounts)
.asInstanceOf[util.Map[LongWritable, LongWritable]]
.get(new LongWritable(1L))
.get())
assert(0.1 === soi.getStructFieldData(b, sfPercentiles)
.asInstanceOf[util.ArrayList[DoubleWritable]]
.get(0)
.get())
}
val data =
Literal(true) ::
Literal(0.asInstanceOf[Byte]) ::
Literal(0.asInstanceOf[Short]) ::
Literal(0) ::
Literal(0.asInstanceOf[Long]) ::
Literal(0.asInstanceOf[Float]) ::
Literal(0.asInstanceOf[Double]) ::
Literal("0") ::
Literal(new Date(2014, 9, 23)) ::
Literal(Decimal(BigDecimal(123.123))) ::
Literal(new java.sql.Timestamp(123123)) ::
Literal(Array[Byte](1,2,3)) ::
Literal(Seq[Int](1,2,3), ArrayType(IntegerType)) ::
Literal(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) ::
Literal(Row(1,2.0d,3.0f),
StructType(StructField("c1", IntegerType) ::
StructField("c2", DoubleType) ::
StructField("c3", FloatType) :: Nil)) ::
Nil
val row = data.map(_.eval(null))
val dataTypes = data.map(_.dataType)
import scala.collection.JavaConversions._
def toWritableInspector(dataType: DataType): ObjectInspector = dataType match {
case ArrayType(tpe, _) =>
ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe))
case MapType(keyType, valueType, _) =>
ObjectInspectorFactory.getStandardMapObjectInspector(
toWritableInspector(keyType), toWritableInspector(valueType))
case StringType => PrimitiveObjectInspectorFactory.writableStringObjectInspector
case IntegerType => PrimitiveObjectInspectorFactory.writableIntObjectInspector
case DoubleType => PrimitiveObjectInspectorFactory.writableDoubleObjectInspector
case BooleanType => PrimitiveObjectInspectorFactory.writableBooleanObjectInspector
case LongType => PrimitiveObjectInspectorFactory.writableLongObjectInspector
case FloatType => PrimitiveObjectInspectorFactory.writableFloatObjectInspector
case ShortType => PrimitiveObjectInspectorFactory.writableShortObjectInspector
case ByteType => PrimitiveObjectInspectorFactory.writableByteObjectInspector
case NullType => PrimitiveObjectInspectorFactory.writableVoidObjectInspector
case BinaryType => PrimitiveObjectInspectorFactory.writableBinaryObjectInspector
case DateType => PrimitiveObjectInspectorFactory.writableDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.writableTimestampObjectInspector
case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toWritableInspector(f.dataType)))
}
def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = {
dt1.zip(dt2).map {
case (dd1, dd2) =>
assert(dd1.getClass === dd2.getClass) // DecimalType doesn't have the default precision info
}
}
def checkValues(row1: Seq[Any], row2: Seq[Any]): Unit = {
row1.zip(row2).map {
case (r1, r2) => checkValues(r1, r2)
}
}
def checkValues(v1: Any, v2: Any): Unit = {
(v1, v2) match {
case (r1: Decimal, r2: Decimal) =>
// Ignore the Decimal precision
assert(r1.compare(r2) === 0)
case (r1: Array[Byte], r2: Array[Byte])
if r1 != null && r2 != null && r1.length == r2.length =>
r1.zip(r2).map { case (b1, b2) => assert(b1 === b2) }
case (r1: Date, r2: Date) => assert(r1.compareTo(r2) === 0)
case (r1, r2) => assert(r1 === r2)
}
}
test("oi => datatype => oi") {
val ois = dataTypes.map(toInspector)
checkDataType(ois.map(inspectorToDataType), dataTypes)
checkDataType(dataTypes.map(toWritableInspector).map(inspectorToDataType), dataTypes)
}
test("wrap / unwrap null, constant null and writables") {
val writableOIs = dataTypes.map(toWritableInspector)
val nullRow = data.map(d => null)
checkValues(nullRow, nullRow.zip(writableOIs).map {
case (d, oi) => unwrap(wrap(d, oi), oi)
})
// a struct can't be a constant, filter it out
val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType])
val constantData = constantExprs.map(_.eval())
val constantNullData = constantData.map(_ => null)
val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType))
val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal(null, e.dataType)))
checkValues(constantData, constantData.zip(constantWritableOIs).map {
case (d, oi) => unwrap(wrap(d, oi), oi)
})
checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map {
case (d, oi) => unwrap(wrap(d, oi), oi)
})
checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map {
case (d, oi) => unwrap(wrap(d, oi), oi)
})
}
test("wrap / unwrap primitive writable object inspector") {
val writableOIs = dataTypes.map(toWritableInspector)
checkValues(row, row.zip(writableOIs).map {
case (data, oi) => unwrap(wrap(data, oi), oi)
})
}
test("wrap / unwrap primitive java object inspector") {
val ois = dataTypes.map(toInspector)
checkValues(row, row.zip(ois).map {
case (data, oi) => unwrap(wrap(data, oi), oi)
})
}
test("wrap / unwrap Struct Type") {
val dt = StructType(dataTypes.zipWithIndex.map {
case (t, idx) => StructField(s"c_$idx", t)
})
checkValues(row, unwrap(wrap(row, toInspector(dt)), toInspector(dt)).asInstanceOf[Row])
checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
}
test("wrap / unwrap Array Type") {
val dt = ArrayType(dataTypes(0))
val d = row(0) :: row(0) :: Nil
checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
}
test("wrap / unwrap Map Type") {
val dt = MapType(dataTypes(0), dataTypes(1))
val d = Map(row(0) -> row(1))
checkValues(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt)))
checkValues(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt)))
checkValues(d, unwrap(wrap(d, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
checkValues(d, unwrap(wrap(null, toInspector(Literal(d, dt))), toInspector(Literal(d, dt))))
}
}

@@ -35,6 +35,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector,
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.{io => hadoopIo}
import org.apache.hadoop.mapred.InputFormat
import org.apache.spark.sql.catalyst.types.decimal.Decimal
@@ -71,76 +72,114 @@
def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.STRING,
-      if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]))
+      getStringWritable(value))
def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.INT,
-      if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]))
+      getIntWritable(value))
def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DOUBLE,
-      if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]))
+      getDoubleWritable(value))
def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BOOLEAN,
-      if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]))
+      getBooleanWritable(value))
def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.LONG,
-      if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]))
+      getLongWritable(value))
def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.FLOAT,
-      if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]))
+      getFloatWritable(value))
def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.SHORT,
-      if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]))
+      getShortWritable(value))
def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BYTE,
-      if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]))
+      getByteWritable(value))
def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.BINARY,
-      if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]))
+      getBinaryWritable(value))
def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DATE,
-      if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]))
+      getDateWritable(value))
def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.TIMESTAMP,
-      if (value == null) {
-        null
-      } else {
-        new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
-      })
+      getTimestampWritable(value))
def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.DECIMAL,
-      if (value == null) {
-        null
-      } else {
-        new hiveIo.HiveDecimalWritable(
-          HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
-      })
+      getDecimalWritable(value))
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
PrimitiveCategory.VOID, null)
def getStringWritable(value: Any): hadoopIo.Text =
if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double])
def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
def getLongWritable(value: Any): hadoopIo.LongWritable =
if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
def getFloatWritable(value: Any): hadoopIo.FloatWritable =
if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float])
def getShortWritable(value: Any): hiveIo.ShortWritable =
if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
def getByteWritable(value: Any): hiveIo.ByteWritable =
if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
def getDateWritable(value: Any): hiveIo.DateWritable =
if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])
def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
if (value == null) {
null
} else {
new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
}
def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
if (value == null) {
null
} else {
new hiveIo.HiveDecimalWritable(
HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
}
def getPrimitiveNullWritable: NullWritable = NullWritable.get()
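For illustration, a minimal sketch of how these null-safe helpers behave (values are made up; it assumes code in the same org.apache.spark.sql.hive package):

{{{
import org.apache.hadoop.io.{IntWritable, Text}

// Each helper returns null for a null input instead of wrapping it, which is what
// both the constant object inspectors above and the new wrap() rely on.
assert(HiveShim.getStringWritable(null) == null)
assert(HiveShim.getStringWritable("ab") == new Text("ab"))
assert(HiveShim.getIntWritable(1) == new IntWritable(1))
}}}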
def createDriverResultsArray = new JArrayList[String]
def processResults(results: JArrayList[String]) = results
@@ -197,7 +236,11 @@
}
def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
-    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+    if (hdoi.preferWritable()) {
+      Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue)
+    } else {
+      Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+    }
}
}

@@ -22,6 +22,7 @@ import java.util.Properties
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.common.`type`.{HiveDecimal}
@@ -163,91 +164,123 @@
new TableDesc(inputFormatClass, outputFormatClass, properties)
}
def getStringWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.stringTypeInfo,
-      if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]))
+      TypeInfoFactory.stringTypeInfo, getStringWritable(value))
def getIntWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.intTypeInfo,
-      if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]))
+      TypeInfoFactory.intTypeInfo, getIntWritable(value))
def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.doubleTypeInfo, if (value == null) {
-        null
-      } else {
-        new hiveIo.DoubleWritable(value.asInstanceOf[Double])
-      })
+      TypeInfoFactory.doubleTypeInfo, getDoubleWritable(value))
def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.booleanTypeInfo, if (value == null) {
-        null
-      } else {
-        new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
-      })
+      TypeInfoFactory.booleanTypeInfo, getBooleanWritable(value))
def getLongWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.longTypeInfo,
-      if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]))
+      TypeInfoFactory.longTypeInfo, getLongWritable(value))
def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.floatTypeInfo, if (value == null) {
-        null
-      } else {
-        new hadoopIo.FloatWritable(value.asInstanceOf[Float])
-      })
+      TypeInfoFactory.floatTypeInfo, getFloatWritable(value))
def getShortWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.shortTypeInfo,
-      if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]))
+      TypeInfoFactory.shortTypeInfo, getShortWritable(value))
def getByteWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.byteTypeInfo,
-      if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]))
+      TypeInfoFactory.byteTypeInfo, getByteWritable(value))
def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.binaryTypeInfo, if (value == null) {
-        null
-      } else {
-        new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
-      })
+      TypeInfoFactory.binaryTypeInfo, getBinaryWritable(value))
def getDateWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.dateTypeInfo,
-      if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date]))
+      TypeInfoFactory.dateTypeInfo, getDateWritable(value))
def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.timestampTypeInfo, if (value == null) {
-        null
-      } else {
-        new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
-      })
+      TypeInfoFactory.timestampTypeInfo, getTimestampWritable(value))
def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
-      TypeInfoFactory.decimalTypeInfo,
-      if (value == null) {
-        null
-      } else {
-        // TODO precise, scale?
-        new hiveIo.HiveDecimalWritable(
-          HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
-      })
+      TypeInfoFactory.decimalTypeInfo, getDecimalWritable(value))
def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector(
TypeInfoFactory.voidTypeInfo, null)
def getStringWritable(value: Any): hadoopIo.Text =
if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String])
def getIntWritable(value: Any): hadoopIo.IntWritable =
if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int])
def getDoubleWritable(value: Any): hiveIo.DoubleWritable =
if (value == null) {
null
} else {
new hiveIo.DoubleWritable(value.asInstanceOf[Double])
}
def getBooleanWritable(value: Any): hadoopIo.BooleanWritable =
if (value == null) {
null
} else {
new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean])
}
def getLongWritable(value: Any): hadoopIo.LongWritable =
if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long])
def getFloatWritable(value: Any): hadoopIo.FloatWritable =
if (value == null) {
null
} else {
new hadoopIo.FloatWritable(value.asInstanceOf[Float])
}
def getShortWritable(value: Any): hiveIo.ShortWritable =
if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short])
def getByteWritable(value: Any): hiveIo.ByteWritable =
if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte])
def getBinaryWritable(value: Any): hadoopIo.BytesWritable =
if (value == null) {
null
} else {
new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]])
}
def getDateWritable(value: Any): hiveIo.DateWritable =
if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[java.sql.Date])
def getTimestampWritable(value: Any): hiveIo.TimestampWritable =
if (value == null) {
null
} else {
new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp])
}
def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable =
if (value == null) {
null
} else {
// TODO precise, scale?
new hiveIo.HiveDecimalWritable(
HiveShim.createDecimal(value.asInstanceOf[Decimal].toBigDecimal.underlying()))
}
def getPrimitiveNullWritable: NullWritable = NullWritable.get()
def createDriverResultsArray = new JArrayList[Object]
def processResults(results: JArrayList[Object]) = {
@@ -355,7 +388,12 @@
}
def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
-    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+    if (hdoi.preferWritable()) {
+      Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue,
+        hdoi.precision(), hdoi.scale())
+    } else {
+      Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+    }
}
}
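For illustration, a minimal sketch of the two branches of toCatalystDecimal (it assumes Hive 0.13.1 APIs, matching this shim, and code in the same org.apache.spark.sql.hive package):

{{{
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

// Writable-preferring inspector: the HiveDecimal is read from the writable directly.
val writableOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector
val d1 = HiveShim.toCatalystDecimal(
  writableOI, new HiveDecimalWritable(HiveDecimal.create("12.25")))

// Java-preferring inspector: the value goes through getPrimitiveJavaObject instead.
val javaOI = PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
val d2 = HiveShim.toCatalystDecimal(javaOI, HiveDecimal.create("12.25"))

assert(d1.toBigDecimal == d2.toBigDecimal)
}}}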