[SPARK-23856][SQL] Add an option queryTimeout
in JDBCOptions
## What changes were proposed in this pull request? This pr added an option `queryTimeout` for the number of seconds the driver will wait for a Statement object to execute. ## How was this patch tested? Added tests in `JDBCSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #21173 from maropu/SPARK-23856.
This commit is contained in:
parent
3159ee085b
commit
a53ea70c1d
|
@ -1338,6 +1338,17 @@ the following case-insensitive options:
|
|||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td><code>queryTimeout</code></td>
|
||||
<td>
|
||||
The number of seconds the driver will wait for a Statement object to execute, up to the given
|
||||
number of seconds. Zero means there is no limit. In the write path, this option depends on
|
||||
how JDBC drivers implement the API <code>setQueryTimeout</code>, e.g., the h2 JDBC driver
|
||||
checks the timeout of each query instead of an entire JDBC batch.
|
||||
It defaults to <code>0</code>.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<tr>
|
||||
<td><code>fetchsize</code></td>
|
||||
<td>
|
||||
|
|
|
@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
|
|||
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
|
||||
* tag/value. Normally at least a "user" and "password" property
|
||||
* should be included. "fetchsize" can be used to control the
|
||||
* number of rows per fetch.
|
||||
* number of rows per fetch and "queryTimeout" can be used to wait
|
||||
* for a Statement object to execute to the given number of seconds.
|
||||
* @since 1.4.0
|
||||
*/
|
||||
def jdbc(
|
||||
|
|
|
@ -89,6 +89,10 @@ class JDBCOptions(
|
|||
// the number of partitions
|
||||
val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
|
||||
|
||||
// the number of seconds the driver will wait for a Statement object to execute to the given
|
||||
// number of seconds. Zero means there is no limit.
|
||||
val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// Optional parameters only for reading
|
||||
// ------------------------------------------------------------
|
||||
|
@ -160,6 +164,7 @@ object JDBCOptions {
|
|||
val JDBC_LOWER_BOUND = newOption("lowerBound")
|
||||
val JDBC_UPPER_BOUND = newOption("upperBound")
|
||||
val JDBC_NUM_PARTITIONS = newOption("numPartitions")
|
||||
val JDBC_QUERY_TIMEOUT = newOption("queryTimeout")
|
||||
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
|
||||
val JDBC_TRUNCATE = newOption("truncate")
|
||||
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
|
||||
|
|
|
@ -57,6 +57,7 @@ object JDBCRDD extends Logging {
|
|||
try {
|
||||
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
val rs = statement.executeQuery()
|
||||
try {
|
||||
JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
|
||||
|
@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD(
|
|||
val statement = conn.prepareStatement(sql)
|
||||
logInfo(s"Executing sessionInitStatement: $sql")
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
statement.execute()
|
||||
} finally {
|
||||
statement.close()
|
||||
|
@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD(
|
|||
stmt = conn.prepareStatement(sqlText,
|
||||
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
|
||||
stmt.setFetchSize(options.fetchSize)
|
||||
stmt.setQueryTimeout(options.queryTimeout)
|
||||
rs = stmt.executeQuery()
|
||||
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
|
|||
saveTable(df, tableSchema, isCaseSensitive, options)
|
||||
} else {
|
||||
// Otherwise, do not truncate the table, instead drop and recreate it
|
||||
dropTable(conn, options.table)
|
||||
dropTable(conn, options.table, options)
|
||||
createTable(conn, df, options)
|
||||
saveTable(df, Some(df.schema), isCaseSensitive, options)
|
||||
}
|
||||
|
|
|
@ -76,6 +76,7 @@ object JdbcUtils extends Logging {
|
|||
Try {
|
||||
val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
statement.executeQuery()
|
||||
} finally {
|
||||
statement.close()
|
||||
|
@ -86,9 +87,10 @@ object JdbcUtils extends Logging {
|
|||
/**
|
||||
* Drops a table from the JDBC database.
|
||||
*/
|
||||
def dropTable(conn: Connection, table: String): Unit = {
|
||||
def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
|
||||
val statement = conn.createStatement
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
statement.executeUpdate(s"DROP TABLE $table")
|
||||
} finally {
|
||||
statement.close()
|
||||
|
@ -102,6 +104,7 @@ object JdbcUtils extends Logging {
|
|||
val dialect = JdbcDialects.get(options.url)
|
||||
val statement = conn.createStatement
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
statement.executeUpdate(dialect.getTruncateQuery(options.table))
|
||||
} finally {
|
||||
statement.close()
|
||||
|
@ -254,6 +257,7 @@ object JdbcUtils extends Logging {
|
|||
try {
|
||||
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
Some(getSchema(statement.executeQuery(), dialect))
|
||||
} catch {
|
||||
case _: SQLException => None
|
||||
|
@ -596,7 +600,8 @@ object JdbcUtils extends Logging {
|
|||
insertStmt: String,
|
||||
batchSize: Int,
|
||||
dialect: JdbcDialect,
|
||||
isolationLevel: Int): Iterator[Byte] = {
|
||||
isolationLevel: Int,
|
||||
options: JDBCOptions): Iterator[Byte] = {
|
||||
val conn = getConnection()
|
||||
var committed = false
|
||||
|
||||
|
@ -637,6 +642,9 @@ object JdbcUtils extends Logging {
|
|||
|
||||
try {
|
||||
var rowCount = 0
|
||||
|
||||
stmt.setQueryTimeout(options.queryTimeout)
|
||||
|
||||
while (iterator.hasNext) {
|
||||
val row = iterator.next()
|
||||
var i = 0
|
||||
|
@ -819,7 +827,8 @@ object JdbcUtils extends Logging {
|
|||
case _ => df
|
||||
}
|
||||
repartitionedDF.rdd.foreachPartition(iterator => savePartition(
|
||||
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
|
||||
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
|
||||
options)
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -841,6 +850,7 @@ object JdbcUtils extends Logging {
|
|||
val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
|
||||
val statement = conn.createStatement
|
||||
try {
|
||||
statement.setQueryTimeout(options.queryTimeout)
|
||||
statement.executeUpdate(sql)
|
||||
} finally {
|
||||
statement.close()
|
||||
|
|
|
@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite
|
|||
assert(sql("select * from people_view").schema === schema)
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that the `queryTimeout` option aborts a long-running read-side query:
// a 100-way self-join of test.people is slow enough that a 1-second timeout fires.
test("SPARK-23856 Spark jdbc setQueryTimeout option") {
  val joinCount = 100
  // Column list and join clauses for the deliberately slow cross-join query.
  val selectedCols = (1 to joinCount).map(i => s"t$i.NAME AS c$i").mkString(", ")
  val joinClauses = (1 to joinCount).map(i => s"join test.people t$i").mkString(" ")
  val slowQuery =
    s"SELECT t0.NAME AS c0, $selectedCols " + s"FROM test.people t0 $joinClauses"
  val timedOutDf = spark.read.format("jdbc")
    .option("Url", urlWithUserAndPass)
    .option("dbtable", s"($slowQuery)")
    .option("queryTimeout", 1)
    .load()
  // Collecting forces execution; the h2 driver cancels the statement on timeout.
  val message = intercept[SparkException](timedOutDf.collect()).getMessage
  assert(message.contains("Statement was canceled or the session timed out"))
}
|
||||
}
|
||||
|
|
|
@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
|
|||
}.getMessage
|
||||
assert(e.contains("NULL not allowed for column \"NAME\""))
|
||||
}
|
||||
|
||||
// Write-side counterpart of the queryTimeout test. Disabled (`ignore`) because the
// semantics of `setQueryTimeout` are driver-dependent: h2 applies the timeout to each
// statement inside `executeBatch`, while e.g. PostgreSQL applies it to the whole batch
// on the driver side. The Spark JDBC write path uses `executeBatch`, and this suite
// runs against h2, so the per-statement timeout never fires and the test would fail.
ignore("SPARK-23856 Spark jdbc setQueryTimeout option") {
  val thrown = intercept[SparkException] {
    spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write
      .mode(SaveMode.Overwrite)
      .option("queryTimeout", 1)
      .option("batchsize", Int.MaxValue)
      .jdbc(url1, "TEST.TIMEOUTTEST", properties)
  }
  assert(thrown.getMessage.contains("Statement was canceled or the session timed out"))
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue