[SPARK-33107][SQL] Remove hive-2.3 workaround code
### What changes were proposed in this pull request? This pr remove `hive-2.3` workaround code. ### Why are the changes needed? Make code more clear and readable. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing unit tests. Closes #29996 from wangyum/SPARK-33107. Authored-by: Yuming Wang <yumwang@ebay.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
7696ca5673
commit
5e170140b0
|
@ -118,8 +118,7 @@ private[hive] class SparkExecuteStatementOperation(
|
|||
validateDefaultFetchOrientation(order)
|
||||
assertState(OperationState.FINISHED)
|
||||
setHasResultSet(true)
|
||||
val resultRowSet: RowSet =
|
||||
ThriftserverShimUtils.resultRowSet(getResultSetSchema, getProtocolVersion)
|
||||
val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
|
||||
|
||||
// Reset iter when FETCH_FIRST or FETCH_PRIOR
|
||||
if ((order.equals(FetchOrientation.FETCH_FIRST) ||
|
||||
|
|
|
@ -20,6 +20,8 @@ package org.apache.spark.sql.hive.thriftserver
|
|||
import java.util.UUID
|
||||
|
||||
import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType
|
||||
import org.apache.hadoop.hive.serde2.thrift.Type
|
||||
import org.apache.hadoop.hive.serde2.thrift.Type._
|
||||
import org.apache.hive.service.cli.OperationState
|
||||
import org.apache.hive.service.cli.operation.GetTypeInfoOperation
|
||||
import org.apache.hive.service.cli.session.HiveSession
|
||||
|
@ -61,7 +63,7 @@ private[hive] class SparkGetTypeInfoOperation(
|
|||
parentSession.getUsername)
|
||||
|
||||
try {
|
||||
ThriftserverShimUtils.supportedType().foreach(typeInfo => {
|
||||
SparkGetTypeInfoUtil.supportedType.foreach(typeInfo => {
|
||||
val rowData = Array[AnyRef](
|
||||
typeInfo.getName, // TYPE_NAME
|
||||
typeInfo.toJavaSQLType.asInstanceOf[AnyRef], // DATA_TYPE
|
||||
|
@ -90,3 +92,13 @@ private[hive] class SparkGetTypeInfoOperation(
|
|||
HiveThriftServer2.eventManager.onStatementFinish(statementId)
|
||||
}
|
||||
}
|
||||
|
||||
private[hive] object SparkGetTypeInfoUtil {
|
||||
val supportedType: Seq[Type] = {
|
||||
Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE,
|
||||
TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE,
|
||||
FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE,
|
||||
DATE_TYPE, TIMESTAMP_TYPE,
|
||||
ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
|
|||
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
|
||||
import org.apache.log4j.Level
|
||||
import org.apache.thrift.transport.TSocket
|
||||
import org.slf4j.LoggerFactory
|
||||
import sun.misc.{Signal, SignalHandler}
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
|
@ -307,7 +308,9 @@ private[hive] object SparkSQLCLIDriver extends Logging {
|
|||
private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
|
||||
private val sessionState = SessionState.get().asInstanceOf[CliSessionState]
|
||||
|
||||
private val console = ThriftserverShimUtils.getConsole
|
||||
private val LOG = LoggerFactory.getLogger(classOf[SparkSQLCLIDriver])
|
||||
|
||||
private val console = new SessionState.LogHelper(LOG)
|
||||
|
||||
private val isRemoteMode = {
|
||||
SparkSQLCLIDriver.isRemoteMode(sessionState)
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.hadoop.hive.conf.HiveConf
|
|||
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
|
||||
import org.apache.hive.service.cli.SessionHandle
|
||||
import org.apache.hive.service.cli.session.SessionManager
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion
|
||||
import org.apache.hive.service.server.HiveServer2
|
||||
|
||||
import org.apache.spark.sql.SQLContext
|
||||
|
@ -45,7 +46,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext:
|
|||
}
|
||||
|
||||
override def openSession(
|
||||
protocol: ThriftserverShimUtils.TProtocolVersion,
|
||||
protocol: TProtocolVersion,
|
||||
username: String,
|
||||
passwd: String,
|
||||
ipAddress: String,
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.hive.thriftserver
|
||||
|
||||
import org.apache.hadoop.hive.ql.session.SessionState
|
||||
import org.apache.hadoop.hive.serde2.thrift.Type
|
||||
import org.apache.hadoop.hive.serde2.thrift.Type._
|
||||
import org.apache.hive.service.cli.{RowSet, RowSetFactory, TableSchema}
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion._
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
/**
|
||||
* Various utilities for hive-thriftserver used to upgrade the built-in Hive.
|
||||
*/
|
||||
private[thriftserver] object ThriftserverShimUtils {
|
||||
|
||||
private[thriftserver] object TOperationType {
|
||||
val GET_TYPE_INFO = org.apache.hive.service.rpc.thrift.TOperationType.GET_TYPE_INFO
|
||||
}
|
||||
|
||||
private[thriftserver] type TProtocolVersion = org.apache.hive.service.rpc.thrift.TProtocolVersion
|
||||
private[thriftserver] type Client = org.apache.hive.service.rpc.thrift.TCLIService.Client
|
||||
private[thriftserver] type TOpenSessionReq = org.apache.hive.service.rpc.thrift.TOpenSessionReq
|
||||
private[thriftserver] type TGetSchemasReq = org.apache.hive.service.rpc.thrift.TGetSchemasReq
|
||||
private[thriftserver] type TGetTablesReq = org.apache.hive.service.rpc.thrift.TGetTablesReq
|
||||
private[thriftserver] type TGetColumnsReq = org.apache.hive.service.rpc.thrift.TGetColumnsReq
|
||||
private[thriftserver] type TGetInfoReq = org.apache.hive.service.rpc.thrift.TGetInfoReq
|
||||
private[thriftserver] type TExecuteStatementReq =
|
||||
org.apache.hive.service.rpc.thrift.TExecuteStatementReq
|
||||
private[thriftserver] type THandleIdentifier =
|
||||
org.apache.hive.service.rpc.thrift.THandleIdentifier
|
||||
private[thriftserver] type TOperationType = org.apache.hive.service.rpc.thrift.TOperationType
|
||||
private[thriftserver] type TOperationHandle = org.apache.hive.service.rpc.thrift.TOperationHandle
|
||||
|
||||
private[thriftserver] def getConsole: SessionState.LogHelper = {
|
||||
val LOG = LoggerFactory.getLogger(classOf[SparkSQLCLIDriver])
|
||||
new SessionState.LogHelper(LOG)
|
||||
}
|
||||
|
||||
private[thriftserver] def resultRowSet(
|
||||
getResultSetSchema: TableSchema,
|
||||
getProtocolVersion: TProtocolVersion): RowSet = {
|
||||
RowSetFactory.create(getResultSetSchema, getProtocolVersion, false)
|
||||
}
|
||||
|
||||
private[thriftserver] def supportedType(): Seq[Type] = {
|
||||
Seq(NULL_TYPE, BOOLEAN_TYPE, STRING_TYPE, BINARY_TYPE,
|
||||
TINYINT_TYPE, SMALLINT_TYPE, INT_TYPE, BIGINT_TYPE,
|
||||
FLOAT_TYPE, DOUBLE_TYPE, DECIMAL_TYPE,
|
||||
DATE_TYPE, TIMESTAMP_TYPE,
|
||||
ARRAY_TYPE, MAP_TYPE, STRUCT_TYPE)
|
||||
}
|
||||
|
||||
private[thriftserver] val testedProtocolVersions = Seq(
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V1,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V2,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V3,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V4,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V5,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V6,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V7,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V8,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V9,
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V10)
|
||||
}
|
|
@ -22,8 +22,7 @@ import java.util.UUID
|
|||
import org.apache.hive.service.cli.OperationHandle
|
||||
import org.apache.hive.service.cli.operation.GetCatalogsOperation
|
||||
import org.apache.hive.service.cli.session.HiveSession
|
||||
|
||||
import org.apache.spark.sql.hive.thriftserver.ThriftserverShimUtils.{THandleIdentifier, TOperationHandle, TOperationType}
|
||||
import org.apache.hive.service.rpc.thrift.{THandleIdentifier, TOperationHandle, TOperationType}
|
||||
|
||||
class GetCatalogsOperationMock(parentSession: HiveSession)
|
||||
extends GetCatalogsOperation(parentSession) {
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.hadoop.hive.conf.HiveConf
|
|||
import org.apache.hive.service.cli.OperationHandle
|
||||
import org.apache.hive.service.cli.operation.{GetCatalogsOperation, Operation, OperationManager}
|
||||
import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl, SessionManager}
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
|
||||
|
@ -39,7 +40,7 @@ class HiveSessionImplSuite extends SparkFunSuite {
|
|||
operationManager = new OperationManagerMock()
|
||||
|
||||
session = new HiveSessionImpl(
|
||||
ThriftserverShimUtils.testedProtocolVersions.head,
|
||||
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
|
||||
"",
|
||||
"",
|
||||
new HiveConf(),
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.apache.hive.jdbc.HiveDriver
|
|||
import org.apache.hive.service.auth.PlainSaslHelper
|
||||
import org.apache.hive.service.cli.{FetchOrientation, FetchType, GetInfoType, RowSet}
|
||||
import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
|
||||
import org.apache.hive.service.rpc.thrift.TCLIService.Client
|
||||
import org.apache.thrift.protocol.TBinaryProtocol
|
||||
import org.apache.thrift.transport.TSocket
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
@ -67,7 +68,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
|
|||
val user = System.getProperty("user.name")
|
||||
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
|
||||
val protocol = new TBinaryProtocol(transport)
|
||||
val client = new ThriftCLIServiceClient(new ThriftserverShimUtils.Client(protocol))
|
||||
val client = new ThriftCLIServiceClient(new Client(protocol))
|
||||
|
||||
transport.open()
|
||||
try f(client) finally transport.close()
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
|
|||
import org.apache.hive.jdbc.HttpBasicAuthInterceptor
|
||||
import org.apache.hive.service.auth.PlainSaslHelper
|
||||
import org.apache.hive.service.cli.thrift.{ThriftCLIService, ThriftCLIServiceClient}
|
||||
import org.apache.hive.service.rpc.thrift.TCLIService.Client
|
||||
import org.apache.http.impl.client.HttpClientBuilder
|
||||
import org.apache.thrift.protocol.TBinaryProtocol
|
||||
import org.apache.thrift.transport.{THttpClient, TSocket}
|
||||
|
@ -115,7 +116,7 @@ trait SharedThriftServer extends SharedSparkSession {
|
|||
}
|
||||
|
||||
val protocol = new TBinaryProtocol(transport)
|
||||
val client = new ThriftCLIServiceClient(new ThriftserverShimUtils.Client(protocol))
|
||||
val client = new ThriftCLIServiceClient(new Client(protocol))
|
||||
|
||||
transport.open()
|
||||
try f(client) finally transport.close()
|
||||
|
|
|
@ -25,6 +25,7 @@ import scala.concurrent.duration._
|
|||
import org.apache.hadoop.hive.conf.HiveConf
|
||||
import org.apache.hive.service.cli.OperationState
|
||||
import org.apache.hive.service.cli.session.{HiveSession, HiveSessionImpl}
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion
|
||||
import org.mockito.Mockito.{doReturn, mock, spy, when, RETURNS_DEEP_STUBS}
|
||||
import org.mockito.invocation.InvocationOnMock
|
||||
|
||||
|
@ -64,7 +65,7 @@ class SparkExecuteStatementOperationSuite extends SparkFunSuite with SharedSpark
|
|||
).foreach { case (finalState, transition) =>
|
||||
test("SPARK-32057 SparkExecuteStatementOperation should not transiently become ERROR " +
|
||||
s"before being set to $finalState") {
|
||||
val hiveSession = new HiveSessionImpl(ThriftserverShimUtils.testedProtocolVersions.head,
|
||||
val hiveSession = new HiveSessionImpl(TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1,
|
||||
"username", "password", new HiveConf, "ip address")
|
||||
hiveSession.open(new util.HashMap)
|
||||
|
||||
|
|
|
@ -255,7 +255,7 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest {
|
|||
|
||||
withJdbcStatement() { statement =>
|
||||
val metaData = statement.getConnection.getMetaData
|
||||
checkResult(metaData.getTypeInfo, ThriftserverShimUtils.supportedType().map(_.getName))
|
||||
checkResult(metaData.getTypeInfo, SparkGetTypeInfoUtil.supportedType.map(_.getName))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,11 +23,12 @@ import java.util.{List => JList, Properties}
|
|||
import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet}
|
||||
import org.apache.hive.service.auth.PlainSaslHelper
|
||||
import org.apache.hive.service.cli.GetInfoType
|
||||
import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TGetInfoReq, TGetTablesReq, TOpenSessionReq, TProtocolVersion}
|
||||
import org.apache.hive.service.rpc.thrift.TCLIService.Client
|
||||
import org.apache.thrift.protocol.TBinaryProtocol
|
||||
import org.apache.thrift.transport.TSocket
|
||||
|
||||
import org.apache.spark.sql.catalyst.util.NumberConverter
|
||||
import org.apache.spark.sql.hive.HiveUtils
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
|
||||
|
@ -35,20 +36,20 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
|
|||
override def mode: ServerMode.Value = ServerMode.binary
|
||||
|
||||
def testExecuteStatementWithProtocolVersion(
|
||||
version: ThriftserverShimUtils.TProtocolVersion,
|
||||
version: TProtocolVersion,
|
||||
sql: String)(f: HiveQueryResultSet => Unit): Unit = {
|
||||
val rawTransport = new TSocket("localhost", serverPort)
|
||||
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
|
||||
val user = System.getProperty("user.name")
|
||||
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
|
||||
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
|
||||
val client = new Client(new TBinaryProtocol(transport))
|
||||
transport.open()
|
||||
var rs: HiveQueryResultSet = null
|
||||
try {
|
||||
val clientProtocol = new ThriftserverShimUtils.TOpenSessionReq(version)
|
||||
val clientProtocol = new TOpenSessionReq(version)
|
||||
val openResp = client.OpenSession(clientProtocol)
|
||||
val sessHandle = openResp.getSessionHandle
|
||||
val execReq = new ThriftserverShimUtils.TExecuteStatementReq(sessHandle, sql)
|
||||
val execReq = new TExecuteStatementReq(sessHandle, sql)
|
||||
val execResp = client.ExecuteStatement(execReq)
|
||||
val stmtHandle = execResp.getOperationHandle
|
||||
|
||||
|
@ -73,23 +74,21 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
|
|||
}
|
||||
}
|
||||
|
||||
def testGetInfoWithProtocolVersion(version: ThriftserverShimUtils.TProtocolVersion): Unit = {
|
||||
def testGetInfoWithProtocolVersion(version: TProtocolVersion): Unit = {
|
||||
val rawTransport = new TSocket("localhost", serverPort)
|
||||
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
|
||||
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
|
||||
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
|
||||
val client = new Client(new TBinaryProtocol(transport))
|
||||
transport.open()
|
||||
try {
|
||||
val clientProtocol = new ThriftserverShimUtils.TOpenSessionReq(version)
|
||||
val clientProtocol = new TOpenSessionReq(version)
|
||||
val openResp = client.OpenSession(clientProtocol)
|
||||
val sessHandle = openResp.getSessionHandle
|
||||
|
||||
val dbVersionReq =
|
||||
new ThriftserverShimUtils.TGetInfoReq(sessHandle, GetInfoType.CLI_DBMS_VER.toTGetInfoType)
|
||||
val dbVersionReq = new TGetInfoReq(sessHandle, GetInfoType.CLI_DBMS_VER.toTGetInfoType)
|
||||
val dbVersion = client.GetInfo(dbVersionReq).getInfoValue.getStringValue
|
||||
|
||||
val dbNameReq =
|
||||
new ThriftserverShimUtils.TGetInfoReq(sessHandle, GetInfoType.CLI_DBMS_NAME.toTGetInfoType)
|
||||
val dbNameReq = new TGetInfoReq(sessHandle, GetInfoType.CLI_DBMS_NAME.toTGetInfoType)
|
||||
val dbName = client.GetInfo(dbNameReq).getInfoValue.getStringValue
|
||||
|
||||
assert(dbVersion === org.apache.spark.SPARK_VERSION)
|
||||
|
@ -102,21 +101,21 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
|
|||
}
|
||||
|
||||
def testGetTablesWithProtocolVersion(
|
||||
version: ThriftserverShimUtils.TProtocolVersion,
|
||||
version: TProtocolVersion,
|
||||
schema: String,
|
||||
tableNamePattern: String,
|
||||
tableTypes: JList[String])(f: HiveQueryResultSet => Unit): Unit = {
|
||||
val rawTransport = new TSocket("localhost", serverPort)
|
||||
val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties)
|
||||
val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
|
||||
val client = new ThriftserverShimUtils.Client(new TBinaryProtocol(transport))
|
||||
val client = new Client(new TBinaryProtocol(transport))
|
||||
transport.open()
|
||||
var rs: HiveQueryResultSet = null
|
||||
try {
|
||||
val clientProtocol = new ThriftserverShimUtils.TOpenSessionReq(version)
|
||||
val clientProtocol = new TOpenSessionReq(version)
|
||||
val openResp = client.OpenSession(clientProtocol)
|
||||
val sessHandle = openResp.getSessionHandle
|
||||
val getTableReq = new ThriftserverShimUtils.TGetTablesReq(sessHandle)
|
||||
val getTableReq = new TGetTablesReq(sessHandle)
|
||||
getTableReq.setSchemaName(schema)
|
||||
getTableReq.setTableName(tableNamePattern)
|
||||
getTableReq.setTableTypes(tableTypes)
|
||||
|
@ -144,7 +143,7 @@ class SparkThriftServerProtocolVersionsSuite extends HiveThriftJdbcTest {
|
|||
}
|
||||
}
|
||||
|
||||
ThriftserverShimUtils.testedProtocolVersions.foreach { version =>
|
||||
TProtocolVersion.values().foreach { version =>
|
||||
test(s"$version get byte type") {
|
||||
testExecuteStatementWithProtocolVersion(version, "SELECT cast(1 as byte)") { rs =>
|
||||
assert(rs.next())
|
||||
|
|
|
@ -17,18 +17,16 @@
|
|||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import java.io.{InputStream, OutputStream}
|
||||
import java.lang.reflect.Method
|
||||
import java.rmi.server.UID
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import com.google.common.base.Objects
|
||||
import org.apache.avro.Schema
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.hive.ql.exec.SerializationUtilities
|
||||
import org.apache.hadoop.hive.ql.exec.UDF
|
||||
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
|
||||
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro
|
||||
|
@ -148,40 +146,12 @@ private[hive] object HiveShim {
|
|||
case _ => false
|
||||
}
|
||||
|
||||
private lazy val serUtilClass =
|
||||
Utils.classForName("org.apache.hadoop.hive.ql.exec.SerializationUtilities")
|
||||
private lazy val utilClass = Utils.classForName("org.apache.hadoop.hive.ql.exec.Utilities")
|
||||
private val deserializeMethodName = "deserializeObjectByKryo"
|
||||
private val serializeMethodName = "serializeObjectByKryo"
|
||||
|
||||
private def findMethod(klass: Class[_], name: String, args: Class[_]*): Method = {
|
||||
val method = klass.getDeclaredMethod(name, args: _*)
|
||||
method.setAccessible(true)
|
||||
method
|
||||
}
|
||||
|
||||
def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
|
||||
val borrowKryo = serUtilClass.getMethod("borrowKryo")
|
||||
val kryo = borrowKryo.invoke(serUtilClass)
|
||||
val deserializeObjectByKryo = findMethod(serUtilClass, deserializeMethodName,
|
||||
kryo.getClass.getSuperclass, classOf[InputStream], classOf[Class[_]])
|
||||
try {
|
||||
deserializeObjectByKryo.invoke(null, kryo, is, clazz).asInstanceOf[UDFType]
|
||||
} finally {
|
||||
serUtilClass.getMethod("releaseKryo", kryo.getClass.getSuperclass).invoke(null, kryo)
|
||||
}
|
||||
SerializationUtilities.deserializePlan(is, clazz).asInstanceOf[UDFType]
|
||||
}
|
||||
|
||||
def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
|
||||
val borrowKryo = serUtilClass.getMethod("borrowKryo")
|
||||
val kryo = borrowKryo.invoke(serUtilClass)
|
||||
val serializeObjectByKryo = findMethod(serUtilClass, serializeMethodName,
|
||||
kryo.getClass.getSuperclass, classOf[Object], classOf[OutputStream])
|
||||
try {
|
||||
serializeObjectByKryo.invoke(null, kryo, function, out)
|
||||
} finally {
|
||||
serUtilClass.getMethod("releaseKryo", kryo.getClass.getSuperclass).invoke(null, kryo)
|
||||
}
|
||||
SerializationUtilities.serializePlan(function, out)
|
||||
}
|
||||
|
||||
def writeExternal(out: java.io.ObjectOutput): Unit = {
|
||||
|
|
|
@ -55,10 +55,8 @@ private[spark] object HiveUtils extends Logging {
|
|||
sc
|
||||
}
|
||||
|
||||
private val hiveVersion = HiveVersionInfo.getVersion
|
||||
|
||||
/** The version of hive used internally by Spark SQL. */
|
||||
val builtinHiveVersion: String = hiveVersion
|
||||
val builtinHiveVersion: String = HiveVersionInfo.getVersion
|
||||
|
||||
val HIVE_METASTORE_VERSION = buildStaticConf("spark.sql.hive.metastore.version")
|
||||
.doc("Version of the Hive metastore. Available options are " +
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import java.lang.{Boolean => JBoolean}
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
@ -39,7 +38,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||
import org.apache.spark.sql.hive.HiveShim._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* Here we cannot extends `ImplicitTypeCasts` to compatible with UDF input data type, the reason is:
|
||||
|
@ -349,11 +347,7 @@ private[hive] case class HiveUDAFFunction(
|
|||
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
|
||||
}
|
||||
|
||||
val clazz = Utils.classForName(classOf[SimpleGenericUDAFParameterInfo].getName)
|
||||
val ctor = clazz.getDeclaredConstructor(
|
||||
classOf[Array[ObjectInspector]], JBoolean.TYPE, JBoolean.TYPE, JBoolean.TYPE)
|
||||
val args = Array[AnyRef](inputInspectors, JBoolean.FALSE, JBoolean.FALSE, JBoolean.FALSE)
|
||||
val parameterInfo = ctor.newInstance(args: _*).asInstanceOf[SimpleGenericUDAFParameterInfo]
|
||||
val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false, false)
|
||||
resolver.getEvaluator(parameterInfo)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue