From 2333a34d390f2fa19b939b8007be0deb31f31d3c Mon Sep 17 00:00:00 2001 From: Daniel van der Ende Date: Fri, 20 Jul 2018 13:03:57 -0700 Subject: [PATCH] [SPARK-22880][SQL] Add cascadeTruncate option to JDBC datasource This commit adds the `cascadeTruncate` option to the JDBC datasource API, for databases that support this functionality (PostgreSQL and Oracle at the moment). This allows for applying a cascading truncate that affects tables that have foreign key constraints on the table being truncated. ## What changes were proposed in this pull request? Add `cascadeTruncate` option to JDBC datasource API. Allow this to affect the `TRUNCATE` query for databases that support this option. ## How was this patch tested? Existing tests for `truncateQuery` were updated. Also, an additional test was added to ensure that the correct syntax was applied, and that enabling the config for databases that do not support this option does not result in invalid queries. Author: Daniel van der Ende Closes #20057 from danielvdende/SPARK-22880. --- docs/sql-programming-guide.md | 7 +++ .../datasources/jdbc/JDBCOptions.scala | 3 ++ .../datasources/jdbc/JdbcUtils.scala | 7 ++- .../spark/sql/jdbc/AggregatedDialect.scala | 13 +++++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 2 + .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++- .../apache/spark/sql/jdbc/OracleDialect.scala | 16 +++++++ .../spark/sql/jdbc/PostgresDialect.scala | 29 ++++++++---- .../spark/sql/jdbc/TeradataDialect.scala | 18 ++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 46 ++++++++++++++++--- 10 files changed, 140 insertions(+), 21 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ad23dae7c6..4bab58aff0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1407,6 +1407,13 @@ the following case-insensitive options: This is a JDBC writer related option. When SaveMode.Overwrite is enabled, this option causes Spark to truncate an existing table instead of dropping and recreating it. This can be more efficient, and prevents the table metadata (e.g., indices) from being removed. However, it will not work in some cases, such as when the new data has a different schema. It defaults to false. This option applies only to writing. + + + cascadeTruncate + + This is a JDBC writer related option. If enabled and supported by the JDBC database (PostgreSQL and Oracle at the moment), this options allows execution of a TRUNCATE TABLE t CASCADE (in the case of PostgreSQL a TRUNCATE TABLE ONLY t CASCADE is executed to prevent inadvertently truncating descendant tables). This will affect other tables, and thus should be used with care. This option applies only to writing. It defaults to the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect. + + createTableOptions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index eea966d309..574aed4958 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -157,6 +157,8 @@ class JDBCOptions( // ------------------------------------------------------------ // if to truncate the table from the JDBC database val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean + + val isCascadeTruncate: Option[Boolean] = parameters.get(JDBC_CASCADE_TRUNCATE).map(_.toBoolean) // the create table option , which can be table_options or partition_options. // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options @@ -225,6 +227,7 @@ object JDBCOptions { val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") + val JDBC_CASCADE_TRUNCATE = newOption("cascadeTruncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 6cc7922396..edea549748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -105,7 +105,12 @@ object JdbcUtils extends Logging { val statement = conn.createStatement try { statement.setQueryTimeout(options.queryTimeout) - statement.executeUpdate(dialect.getTruncateQuery(options.table)) + val truncateQuery = if (options.isCascadeTruncate.isDefined) { + dialect.getTruncateQuery(options.table, options.isCascadeTruncate) + } else { + dialect.getTruncateQuery(options.table) + } + statement.executeUpdate(truncateQuery) } finally { statement.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 8b92c8b4f5..3a3246a1b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -64,7 +64,16 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect } } - override def getTruncateQuery(table: String): String = { - dialects.head.getTruncateQuery(table) + /** + * The SQL query used to truncate a table. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + dialects.head.getTruncateQuery(table, cascade) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 84f68e779c..d13c29ed46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -41,4 +41,6 @@ private object DerbyDialect extends JdbcDialect { Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) case _ => None } + + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 83d87a1181..f76c1fae56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -22,6 +22,7 @@ import java.sql.{Connection, Date, Timestamp} import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} +import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions import org.apache.spark.sql.types._ /** @@ -120,12 +121,27 @@ abstract class JdbcDialect extends Serializable { * The SQL query that should be used to truncate a table. Dialects can override this method to * return a query that is suitable for a particular database. For PostgreSQL, for instance, * a different query is used to prevent "TRUNCATE" affecting other tables. - * @param table The name of the table. + * @param table The table to truncate * @return The SQL query to use for truncating a table */ @Since("2.3.0") def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE $table" + getTruncateQuery(table, isCascadingTruncateTable) + } + + /** + * The SQL query that should be used to truncate a table. Dialects can override this method to + * return a query that is suitable for a particular database. For PostgreSQL, for instance, + * a different query is used to prevent "TRUNCATE" affecting other tables. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation + * @return The SQL query to use for truncating a table + */ + @Since("2.4.0") + def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"TRUNCATE TABLE $table" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6ef77f2446..f4a6d0a4d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -95,4 +95,20 @@ private case object OracleDialect extends JdbcDialect { } override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable() + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE $table CASCADE" + case _ => s"TRUNCATE TABLE $table" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 13a2035f4d..f8d2bc8e0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -85,15 +85,27 @@ private object PostgresDialect extends JdbcDialect { s"SELECT 1 FROM $table LIMIT 1" } + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + /** - * The SQL query used to truncate a table. For Postgres, the default behaviour is to - * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, - * the Postgres dialect adds 'ONLY' to truncate only the table in question - * @param table The name of the table. - * @return The SQL query to use for truncating a table - */ - override def getTruncateQuery(table: String): String = { - s"TRUNCATE TABLE ONLY $table" + * The SQL query used to truncate a table. For Postgres, the default behaviour is to + * also truncate any descendant tables. As this is a (possibly unwanted) side-effect, + * the Postgres dialect adds 'ONLY' to truncate only the table in question + * @param table The table to truncate + * @param cascade Whether or not to cascade the truncation. Default value is the value of + * isCascadingTruncateTable(). Cascading a truncation will truncate tables + * with a foreign key relationship to the target table. However, it will not + * truncate tables with an inheritance relationship to the target table, as + * the truncate query always includes "ONLY" to prevent this behaviour. + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + cascade match { + case Some(true) => s"TRUNCATE TABLE ONLY $table CASCADE" + case _ => s"TRUNCATE TABLE ONLY $table" + } } override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { @@ -110,5 +122,4 @@ private object PostgresDialect extends JdbcDialect { } } - override def isCascadingTruncateTable(): Option[Boolean] = Some(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 5749b791fc..6c17bd7ed9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -31,4 +31,22 @@ private case object TeradataDialect extends JdbcDialect { case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) case _ => None } + + // Teradata does not support cascading a truncation + override def isCascadingTruncateTable(): Option[Boolean] = Some(false) + + /** + * The SQL query used to truncate a table. Teradata does not support the 'TRUNCATE' syntax that + * other dialects use. Instead, we need to use a 'DELETE FROM' statement. + * @param table The table to truncate. + * @param cascade Whether or not to cascade the truncation. Default value is the + * value of isCascadingTruncateTable(). Teradata does not support cascading a + * 'DELETE FROM' statement (and as mentioned, does not support 'TRUNCATE' syntax) + * @return The SQL query to use for truncating a table + */ + override def getTruncateQuery( + table: String, + cascade: Option[Boolean] = isCascadingTruncateTable): String = { + s"DELETE FROM $table ALL" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0389273d6c..09facb9bef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -861,19 +861,51 @@ class JDBCSuite extends QueryTest } test("truncate table query by jdbc dialect") { - val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") - val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") val h2 = JdbcDialects.get(url) val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + val table = "weblogs" val defaultQuery = s"TRUNCATE TABLE $table" val postgresQuery = s"TRUNCATE TABLE ONLY $table" - assert(MySQL.getTruncateQuery(table) == defaultQuery) - assert(Postgres.getTruncateQuery(table) == postgresQuery) - assert(db2.getTruncateQuery(table) == defaultQuery) - assert(h2.getTruncateQuery(table) == defaultQuery) - assert(derby.getTruncateQuery(table) == defaultQuery) + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + + assert(postgres.getTruncateQuery(table) == postgresQuery) + assert(oracle.getTruncateQuery(table) == defaultQuery) + assert(teradata.getTruncateQuery(table) == teradataQuery) + } + + test("SPARK-22880: Truncate table with CASCADE by jdbc dialect") { + // cascade in a truncate should only be applied for databases that support this, + // even if the parameter is passed. + val mysql = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val derby = JdbcDialects.get("jdbc:derby:db") + val oracle = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val teradata = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + + val table = "weblogs" + val defaultQuery = s"TRUNCATE TABLE $table" + val postgresQuery = s"TRUNCATE TABLE ONLY $table CASCADE" + val oracleQuery = s"TRUNCATE TABLE $table CASCADE" + val teradataQuery = s"DELETE FROM $table ALL" + + Seq(mysql, db2, h2, derby).foreach{ dialect => + assert(dialect.getTruncateQuery(table, Some(true)) == defaultQuery) + } + assert(postgres.getTruncateQuery(table, Some(true)) == postgresQuery) + assert(oracle.getTruncateQuery(table, Some(true)) == oracleQuery) + assert(teradata.getTruncateQuery(table, Some(true)) == teradataQuery) } test("Test DataFrame.where for Date and Timestamp") {