diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index e3efa2d3ae..b665d4a31b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver
 
 import java.io._
 import java.nio.charset.StandardCharsets.UTF_8
-import java.util.{ArrayList => JArrayList, Locale}
+import java.util.{ArrayList => JArrayList, List => JList, Locale}
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
@@ -37,6 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
 import org.apache.log4j.Level
 import org.apache.thrift.transport.TSocket
+import sun.misc.{Signal, SignalHandler}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -438,5 +439,112 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
       ret
     }
   }
+
+  // Adapted processLine from Hive 2.3's CliDriver.processLine.
+  override def processLine(line: String, allowInterrupting: Boolean): Int = {
+    var oldSignal: SignalHandler = null
+    var interruptSignal: Signal = null
+
+    if (allowInterrupting) {
+      // Remember all threads that were running at the time we started line processing.
+      // Hook up the custom Ctrl+C handler while processing this line
+      interruptSignal = new Signal("INT")
+      oldSignal = Signal.handle(interruptSignal, new SignalHandler() {
+        private var interruptRequested: Boolean = false
+
+        override def handle(signal: Signal) {
+          val initialRequest = !interruptRequested
+          interruptRequested = true
+
+          // Kill the VM on second ctrl+c
+          if (!initialRequest) {
+            console.printInfo("Exiting the JVM")
+            System.exit(127)
+          }
+
+          // Interrupt the CLI thread to stop the current statement and return
+          // to prompt
+          console.printInfo("Interrupting... Be patient, this might take some time.")
+          console.printInfo("Press Ctrl+C again to kill JVM")
+
+          HiveInterruptUtils.interrupt()
+        }
+      })
+    }
+
+    try {
+      var lastRet: Int = 0
+
+      // we can not use "split" function directly as ";" may be quoted
+      val commands = splitSemiColon(line).asScala
+      var command: String = ""
+      for (oneCmd <- commands) {
+        if (StringUtils.endsWith(oneCmd, "\\")) {
+          command += StringUtils.chop(oneCmd) + ";"
+        } else {
+          command += oneCmd
+          if (!StringUtils.isBlank(command)) {
+            val ret = processCmd(command)
+            command = ""
+            lastRet = ret
+            val ignoreErrors = HiveConf.getBoolVar(conf, HiveConf.ConfVars.CLIIGNOREERRORS)
+            if (ret != 0 && !ignoreErrors) {
+              CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+              return ret
+            }
+          }
+        }
+      }
+      CommandProcessorFactory.clean(conf.asInstanceOf[HiveConf])
+      lastRet
+    } finally {
+      // Once we are done processing the line, restore the old handler
+      if (oldSignal != null && interruptSignal != null) {
+        Signal.handle(interruptSignal, oldSignal)
+      }
+    }
+  }
+
+  // Adapted splitSemiColon from Hive 2.3's CliDriver.splitSemiColon.
+  private def splitSemiColon(line: String): JList[String] = {
+    var insideSingleQuote = false
+    var insideDoubleQuote = false
+    var escape = false
+    var beginIndex = 0
+    val ret = new JArrayList[String]
+    for (index <- 0 until line.length) {
+      if (line.charAt(index) == '\'') {
+        // take a look to see if it is escaped
+        if (!escape) {
+          // flip the boolean variable
+          insideSingleQuote = !insideSingleQuote
+        }
+      } else if (line.charAt(index) == '\"') {
+        // take a look to see if it is escaped
+        if (!escape) {
+          // flip the boolean variable
+          insideDoubleQuote = !insideDoubleQuote
+        }
+      } else if (line.charAt(index) == ';') {
+        if (insideSingleQuote || insideDoubleQuote) {
+          // do not split
+        } else {
+          // split, do not include ; itself
+          ret.add(line.substring(beginIndex, index))
+          beginIndex = index + 1
+        }
+      } else {
+        // nothing to do
+      }
+      // set the escape
+      if (escape) {
+        escape = false
+      } else if (line.charAt(index) == '\\') {
+        escape = true
+      }
+    }
+    ret.add(line.substring(beginIndex))
+    ret
+  }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index f3063675a7..04e7f579ff 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -384,4 +384,13 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
       -> ""
     )
   }
+
+  test("SPARK-26321 Should not split semicolon within quoted string literals") {
+    runCliWithin(3.minute)(
+      """select 'Test1', "^;^";""" -> "Test1\t^;^",
+      """select 'Test2', "\";";""" -> "Test2\t\";",
+      """select 'Test3', "\';";""" -> "Test3\t';",
+      "select concat('Test4', ';');" -> "Test4;"
+    )
+  }
 }
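For reference, here is a minimal standalone sketch (not part of the patch; the object name and the extracted function are illustrative only) of the same quote-aware splitting logic, which makes it easy to check by hand that a `;` inside a quoted literal is preserved while the trailing statement terminator still splits the line:

```scala
// Illustrative sketch only: mirrors the quote/escape tracking used by
// splitSemiColon above, extracted so its behaviour can be exercised directly.
object SplitSemiColonDemo {
  def splitSemiColon(line: String): Seq[String] = {
    var insideSingleQuote = false
    var insideDoubleQuote = false
    var escape = false
    var beginIndex = 0
    val parts = scala.collection.mutable.ArrayBuffer.empty[String]
    for (index <- 0 until line.length) {
      val c = line.charAt(index)
      if (c == '\'') {
        if (!escape) insideSingleQuote = !insideSingleQuote
      } else if (c == '"') {
        if (!escape) insideDoubleQuote = !insideDoubleQuote
      } else if (c == ';' && !insideSingleQuote && !insideDoubleQuote) {
        // split here, dropping the ';' itself
        parts += line.substring(beginIndex, index)
        beginIndex = index + 1
      }
      // an unescaped backslash escapes the next character
      escape = if (escape) false else c == '\\'
    }
    parts += line.substring(beginIndex)
    parts.toSeq
  }

  def main(args: Array[String]): Unit = {
    // The ';' inside the double-quoted literal is kept; only the final
    // terminator splits, so the result has two elements: the statement
    // itself and an empty trailing segment.
    println(splitSemiColon("""select 'Test1', "^;^";"""))
  }
}
```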