[SPARK-23856][SQL] Add an option queryTimeout in JDBCOptions

## What changes were proposed in this pull request?
This PR adds an option `queryTimeout`, the number of seconds the driver will wait for a `Statement` object to execute; zero (the default) means no limit.
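For example, the option can be set like any other JDBC data source option. A minimal read-path sketch (the URL, table, and credentials below are placeholders, not part of this patch):

```scala
// Ask the driver to cancel any statement it has waited on for more than
// 10 seconds; 0 (the default) means no limit.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/mydb")   // placeholder URL
  .option("dbtable", "schema.people")                       // placeholder table
  .option("user", "username")
  .option("password", "password")
  .option("queryTimeout", 10)
  .load()
```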

## How was this patch tested?
Added tests in `JDBCSuite`, plus an ignored, driver-dependent write-path test in `JDBCWriteSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #21173 from maropu/SPARK-23856.
Takeshi Yamamuro 2018-05-18 13:38:36 -07:00 committed by gatorsmile
parent 3159ee085b
commit a53ea70c1d
8 changed files with 69 additions and 5 deletions

@@ -1338,6 +1338,17 @@ the following case-insensitive options:
</td>
</tr>
<tr>
<td><code>queryTimeout</code></td>
<td>
The number of seconds the driver will wait for a Statement object to execute. Zero means
there is no limit. In the write path, this option depends on how JDBC drivers implement
the API <code>setQueryTimeout</code>, e.g., the h2 JDBC driver checks the timeout of each
query instead of an entire JDBC batch. It defaults to <code>0</code>.
</td>
</tr>
<tr>
<td><code>fetchsize</code></td>
<td>

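To make the write-path caveat above concrete, a hedged sketch of setting the option on write (URL, table, and credentials are placeholders); whether the five-second timeout is enforced per INSERT or per JDBC batch depends on the driver's `setQueryTimeout` implementation:

```scala
// Write-path sketch: the same option key is honored by the JDBC writer.
df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/mydb")   // placeholder URL
  .option("dbtable", "schema.output_table")                 // placeholder table
  .option("user", "username")
  .option("password", "password")
  .option("queryTimeout", 5)
  .mode("append")
  .save()
```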
@@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included. "fetchsize" can be used to control the
* number of rows per fetch.
* number of rows per fetch and "queryTimeout" can be used to set the
* number of seconds the driver will wait for a Statement object to execute.
* @since 1.4.0
*/
def jdbc(

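To illustrate the javadoc above, a hedged sketch of passing "queryTimeout" (and "fetchsize") through connectionProperties; the URL, table, and credentials are placeholders:

```scala
import java.util.Properties

val props = new Properties()
props.setProperty("user", "username")
props.setProperty("password", "password")
props.setProperty("fetchsize", "1000")     // rows per fetch
props.setProperty("queryTimeout", "30")    // seconds; "0" (the default) means no limit

// Properties-based variant of the jdbc() reader documented above.
val people = spark.read.jdbc("jdbc:postgresql://db.example.com/mydb", "schema.people", props)
```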
@@ -89,6 +89,10 @@ class JDBCOptions(
// the number of partitions
val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
// the number of seconds the driver will wait for a Statement object to execute.
// Zero means there is no limit.
val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt
// ------------------------------------------------------------
// Optional parameters only for reading
// ------------------------------------------------------------
@@ -160,6 +164,7 @@ object JDBCOptions {
val JDBC_LOWER_BOUND = newOption("lowerBound")
val JDBC_UPPER_BOUND = newOption("upperBound")
val JDBC_NUM_PARTITIONS = newOption("numPartitions")
val JDBC_QUERY_TIMEOUT = newOption("queryTimeout")
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")

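As a stand-alone illustration of the parsing rule above (a sketch, not the JDBCOptions code itself): the value arrives as a string in a case-insensitive parameter map and falls back to "0", i.e. no limit, when the key is absent:

```scala
// Hypothetical re-creation of the lookup; JDBCOptions wraps its parameters in a
// CaseInsensitiveMap, so "queryTimeout", "querytimeout" and "QUERYTIMEOUT" are
// all equivalent keys.
val parameters = Map("url" -> "jdbc:h2:mem:testdb", "dbtable" -> "TEST.PEOPLE",
  "QueryTimeout" -> "30").map { case (k, v) => (k.toLowerCase, v) }

val queryTimeout: Int = parameters.getOrElse("querytimeout", "0").toInt  // 30 here, 0 if unset
```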
@@ -57,6 +57,7 @@ object JDBCRDD extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
try {
statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
try {
JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
@@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD(
val statement = conn.prepareStatement(sql)
logInfo(s"Executing sessionInitStatement: $sql")
try {
statement.setQueryTimeout(options.queryTimeout)
statement.execute()
} finally {
statement.close()
@@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD(
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

@@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
saveTable(df, tableSchema, isCaseSensitive, options)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, options.table)
dropTable(conn, options.table, options)
createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}

@@ -76,6 +76,7 @@ object JdbcUtils extends Logging {
Try {
val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
try {
statement.setQueryTimeout(options.queryTimeout)
statement.executeQuery()
} finally {
statement.close()
@@ -86,9 +87,10 @@ object JdbcUtils extends Logging {
/**
* Drops a table from the JDBC database.
*/
def dropTable(conn: Connection, table: String): Unit = {
def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
val statement = conn.createStatement
try {
statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(s"DROP TABLE $table")
} finally {
statement.close()
@@ -102,6 +104,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(options.url)
val statement = conn.createStatement
try {
statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(dialect.getTruncateQuery(options.table))
} finally {
statement.close()
@@ -254,6 +257,7 @@ object JdbcUtils extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
try {
statement.setQueryTimeout(options.queryTimeout)
Some(getSchema(statement.executeQuery(), dialect))
} catch {
case _: SQLException => None
@@ -596,7 +600,8 @@ object JdbcUtils extends Logging {
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
isolationLevel: Int,
options: JDBCOptions): Iterator[Byte] = {
val conn = getConnection()
var committed = false
@@ -637,6 +642,9 @@ object JdbcUtils extends Logging {
try {
var rowCount = 0
stmt.setQueryTimeout(options.queryTimeout)
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
@@ -819,7 +827,8 @@ object JdbcUtils extends Logging {
case _ => df
}
repartitionedDF.rdd.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
options)
)
}
@@ -841,6 +850,7 @@ object JdbcUtils extends Logging {
val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
val statement = conn.createStatement
try {
statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(sql)
} finally {
statement.close()

@@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite
assert(sql("select * from people_view").schema === schema)
}
}
test("SPARK-23856 Spark jdbc setQueryTimeout option") {
val numJoins = 100
val longRunningQuery =
s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " +
s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}"
val df = spark.read.format("jdbc")
.option("Url", urlWithUserAndPass)
.option("dbtable", s"($longRunningQuery)")
.option("queryTimeout", 1)
.load()
val errMsg = intercept[SparkException] {
df.collect()
}.getMessage
assert(errMsg.contains("Statement was canceled or the session timed out"))
}
}

@@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
}.getMessage
assert(e.contains("NULL not allowed for column \"NAME\""))
}
ignore("SPARK-23856 Spark jdbc setQueryTimeout option") {
// The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API
// `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple
// INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout
// of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver
// checks the timeout of an entire batch in a driver side. So, the test below fails because
// this test suite depends on the h2 JDBC driver and the JDBC write path internally
// uses `executeBatch`.
val errMsg = intercept[SparkException] {
spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write
.mode(SaveMode.Overwrite)
.option("queryTimeout", 1)
.option("batchsize", Int.MaxValue)
.jdbc(url1, "TEST.TIMEOUTTEST", properties)
}.getMessage
assert(errMsg.contains("Statement was canceled or the session timed out"))
}
}
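To illustrate the per-query vs. per-batch distinction described in the ignored test's comment, a raw-JDBC sketch outside Spark (the driver URL and table are placeholders, and the table is assumed to already exist); which unit the one-second timeout applies to is up to the driver:

```scala
import java.sql.DriverManager

// Mirrors what the Spark write path does internally: buffer rows with addBatch()
// and flush them with executeBatch(), after setQueryTimeout on the statement.
val conn = DriverManager.getConnection("jdbc:h2:mem:testdb")   // placeholder URL
try {
  // Placeholder table name, assumed to exist already.
  val stmt = conn.prepareStatement("INSERT INTO TEST.TIMEOUTTEST VALUES (?, ?)")
  try {
    stmt.setQueryTimeout(1)   // seconds, the same unit as the Spark option
    (0L until 10000L).foreach { i =>
      stmt.setLong(1, i)
      stmt.setLong(2, i)
      stmt.addBatch()
    }
    // h2 checks the timeout against each buffered INSERT (each of which is fast,
    // so the timeout never fires), while e.g. the PostgreSQL driver checks it
    // against the whole executeBatch() call, which is why the h2-based test
    // above is ignored.
    stmt.executeBatch()
  } finally {
    stmt.close()
  }
} finally {
  conn.close()
}
```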