diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 139d8e897b..ebb2575416 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -23,9 +23,8 @@ import java.sql.{Date, DriverManager, SQLException, Statement}
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.duration._
-import scala.concurrent.{Await, Promise, future}
+import scala.concurrent.{Await, ExecutionContext, Promise, future}
 import scala.io.Source
 import scala.util.{Random, Try}
 
@@ -43,7 +42,7 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.hive.HiveContext
 import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 import org.apache.spark.{Logging, SparkFunSuite}
 
 object TestData {
@@ -356,31 +355,54 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
         s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
 
       queries.foreach(statement.execute)
+      implicit val ec = ExecutionContext.fromExecutorService(
+        ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel"))
+      try {
+        // Start a very-long-running query that will take hours to finish, then cancel it in order
+        // to demonstrate that cancellation works.
+        val f = future {
+          statement.executeQuery(
+            "SELECT COUNT(*) FROM test_map " +
+            List.fill(10)("join test_map").mkString(" "))
+        }
+        // Note that this is slightly race-prone: if the cancel is issued before the statement
+        // begins executing then we'll fail with a timeout. As a result, this fixed delay is set
+        // slightly more conservatively than may be strictly necessary.
+        Thread.sleep(1000)
+        statement.cancel()
+        val e = intercept[SQLException] {
+          Await.result(f, 3.minute)
+        }
+        assert(e.getMessage.contains("cancelled"))
 
-      val largeJoin = "SELECT COUNT(*) FROM test_map " +
-        List.fill(10)("join test_map").mkString(" ")
-      val f = future { Thread.sleep(100); statement.cancel(); }
-      val e = intercept[SQLException] {
-        statement.executeQuery(largeJoin)
+        // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false
+        statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
+        try {
+          val sf = future {
+            statement.executeQuery(
+              "SELECT COUNT(*) FROM test_map " +
+              List.fill(4)("join test_map").mkString(" ")
+            )
+          }
+          // Similarly, this is also slightly race-prone on fast machines where the query above
+          // might race and complete before we issue the cancel.
+          Thread.sleep(1000)
+          statement.cancel()
+          val rs1 = Await.result(sf, 3.minute)
+          rs1.next()
+          assert(rs1.getInt(1) === math.pow(5, 5))
+          rs1.close()
+
+          val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
+          rs2.next()
+          assert(rs2.getInt(1) === 5)
+          rs2.close()
+        } finally {
+          statement.executeQuery("SET spark.sql.hive.thriftServer.async=true")
+        }
+      } finally {
+        ec.shutdownNow()
       }
-      assert(e.getMessage contains "cancelled")
-      Await.result(f, 3.minute)
-
-      // cancel is a noop
-      statement.executeQuery("SET spark.sql.hive.thriftServer.async=false")
-      val sf = future { Thread.sleep(100); statement.cancel(); }
-      val smallJoin = "SELECT COUNT(*) FROM test_map " +
-        List.fill(4)("join test_map").mkString(" ")
-      val rs1 = statement.executeQuery(smallJoin)
-      Await.result(sf, 3.minute)
-      rs1.next()
-      assert(rs1.getInt(1) === math.pow(5, 5))
-      rs1.close()
-
-      val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map")
-      rs2.next()
-      assert(rs2.getInt(1) === 5)
-      rs2.close()
     }
   }
 
@@ -817,6 +839,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
   }
 
   override protected def beforeAll(): Unit = {
+    super.beforeAll()
     // Chooses a random port between 10000 and 19999
     listeningPort = 10000 + Random.nextInt(10000)
     diagnosisBuffer.clear()
@@ -838,7 +861,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
   }
 
   override protected def afterAll(): Unit = {
-    stopThriftServer()
-    logInfo("HiveThriftServer2 stopped")
+    try {
+      stopThriftServer()
+      logInfo("HiveThriftServer2 stopped")
+    } finally {
+      super.afterAll()
+    }
   }
 }
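
For reference, the cancellation pattern the patched test relies on can be sketched outside the suite. The snippet below is illustrative only: the JDBC URL, credentials, and the `test_map` table are assumptions (the real suite starts its own HiveThriftServer2 on a random port and loads `test_map` itself), and a plain JDK executor stands in for Spark's private `ThreadUtils.newDaemonSingleThreadExecutor`.

```scala
import java.sql.{DriverManager, SQLException}
import java.util.concurrent.Executors

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}

object JdbcCancelSketch extends App {
  // Assumed endpoint and table; adjust to a real HiveServer2-compatible server.
  val conn = DriverManager.getConnection(
    "jdbc:hive2://localhost:10000/default", "user", "")
  val statement = conn.createStatement()

  // Dedicated single worker thread, standing in for
  // ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel") in the patch.
  implicit val ec = ExecutionContext.fromExecutorService(
    Executors.newSingleThreadExecutor())
  try {
    // Run the blocking executeQuery off the caller's thread, so the caller
    // stays free to issue the cancel.
    val f = Future {
      statement.executeQuery("SELECT COUNT(*) FROM test_map " +
        List.fill(10)("join test_map").mkString(" "))
    }
    // Same fixed delay as the patch: give the query time to start
    // (race-prone, as the patch's comments note).
    Thread.sleep(1000)
    statement.cancel()

    // The server aborts the query, executeQuery throws, and the future
    // fails promptly instead of blocking for hours.
    try Await.result(f, 3.minutes)
    catch { case e: SQLException => assert(e.getMessage.contains("cancelled")) }
  } finally {
    ec.shutdownNow()
    conn.close()
  }
}
```

This also illustrates why the patch drops `ExecutionContext.Implicits.global` in favor of its own executor: a dedicated, shut-down-on-exit thread makes the query/cancel ordering deterministic and keeps the future from competing with other test work on the shared global pool.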