[SPARK-15425][SQL] Disallow cross joins by default
## What changes were proposed in this pull request? To prevent users from inadvertently writing queries with cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (`false` by default). Unless it is set to `true`, a query that contains one or more cartesian products fails with an `AnalysisException`. ## How was this patch tested? Added a test to verify the new behavior in `JoinSuite`. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable cartesian products. Author: Sameer Agarwal <sameer@databricks.com> Closes #13209 from sameeragarwal/disallow-cartesian.
This commit is contained in:
parent
fc44b694bf
commit
dafcb05c2e
|
@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
|
|||
}
|
||||
// This join could be very slow or OOM
|
||||
joins.BroadcastNestedLoopJoinExec(
|
||||
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
|
||||
planLater(left), planLater(right), buildSide, joinType, condition,
|
||||
withinBroadcastThreshold = false) :: Nil
|
||||
|
||||
// --- Cases where this strategy does not apply ---------------------------------------------
|
||||
|
||||
|
|
|
@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins
|
|||
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
|
||||
|
||||
case class BroadcastNestedLoopJoinExec(
|
||||
|
@ -32,7 +34,8 @@ case class BroadcastNestedLoopJoinExec(
|
|||
right: SparkPlan,
|
||||
buildSide: BuildSide,
|
||||
joinType: JoinType,
|
||||
condition: Option[Expression]) extends BinaryExecNode {
|
||||
condition: Option[Expression],
|
||||
withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
|
||||
|
||||
override private[sql] lazy val metrics = Map(
|
||||
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
|
||||
|
@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
|
|||
)
|
||||
}
|
||||
|
||||
/**
 * Guard that runs before `super.doPrepare()`: rejects this nested-loop join when it was
 * created with `withinBroadcastThreshold = false` and cross joins are not enabled in the
 * session conf.
 *
 * Throws an `AnalysisException` whose message names the `SQLConf.CROSS_JOINS_ENABLED` key
 * that has to be set to `true` to allow the join.
 */
protected override def doPrepare(): Unit = {
|
||||
// Reject only when both hold: the join is flagged as outside the broadcast threshold
// and the user has not explicitly opted in to cross joins.
if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
|
||||
throw new AnalysisException("Both sides of this join are outside the broadcasting " +
|
||||
"threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
|
||||
s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
|
||||
}
|
||||
// Checks passed; continue with the parent's preparation.
super.doPrepare()
|
||||
}
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
|
||||
|
||||
|
|
|
@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
|
|||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
|
||||
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.util.CompletionIterator
|
||||
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
|
||||
|
||||
|
@ -88,6 +90,15 @@ case class CartesianProductExec(
|
|||
override private[sql] lazy val metrics = Map(
|
||||
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
|
||||
|
||||
/**
 * Guard that runs before `super.doPrepare()`: rejects the cartesian product unless cross
 * joins are enabled in the session conf.
 *
 * Throws an `AnalysisException` whose message names the `SQLConf.CROSS_JOINS_ENABLED` key
 * that has to be set to `true` to allow the join.
 */
protected override def doPrepare(): Unit = {
|
||||
// Unlike a conditional check on plan shape, this operator is always gated by the conf.
if (!sqlContext.conf.crossJoinEnabled) {
|
||||
throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
|
||||
"disabled by default. To explicitly enable them, please set " +
|
||||
s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
|
||||
}
|
||||
// Checks passed; continue with the parent's preparation.
super.doPrepare()
|
||||
}
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val numOutputRows = longMetric("numOutputRows")
|
||||
|
||||
|
|
|
@ -338,9 +338,14 @@ object SQLConf {
|
|||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
// Opt-in switch for cross (cartesian) joins, keyed "spark.sql.crossJoin.enabled";
// created with a default of false.
val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
|
||||
.doc("When false, we will throw an error if a query contains a cross join")
|
||||
.booleanConf
|
||||
.createWithDefault(false)
|
||||
|
||||
val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
|
||||
.doc("When true, the ordinal numbers are treated as the position in the select list. " +
|
||||
"When false, the ordinal numbers in order/sort By clause are ignored.")
|
||||
"When false, the ordinal numbers in order/sort by clause are ignored.")
|
||||
.booleanConf
|
||||
.createWithDefault(true)
|
||||
|
||||
|
@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
|
|||
|
||||
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
|
||||
|
||||
// Current value of the SQLConf.CROSS_JOINS_ENABLED entry for this conf.
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
|
||||
|
||||
// Do not use a value larger than 4000 as the default value of this property.
|
||||
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
|
||||
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
|
||||
|
|
|
@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
|
|||
test("join operator selection") {
|
||||
spark.cacheManager.clearCache()
|
||||
|
||||
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
|
||||
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0",
|
||||
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
Seq(
|
||||
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
|
||||
classOf[SortMergeJoinExec]),
|
||||
|
@ -204,13 +205,27 @@ class JoinSuite extends QueryTest with SharedSQLContext {
|
|||
testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
|
||||
}
|
||||
|
||||
test("cartisian product join") {
|
||||
checkAnswer(
|
||||
testData3.join(testData3),
|
||||
Row(1, null, 1, null) ::
|
||||
Row(1, null, 2, 2) ::
|
||||
Row(2, 2, 1, null) ::
|
||||
Row(2, 2, 2, 2) :: Nil)
|
||||
test("cartesian product join") {
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
testData3.join(testData3),
|
||||
Row(1, null, 1, null) ::
|
||||
Row(1, null, 2, 2) ::
|
||||
Row(2, 2, 1, null) ::
|
||||
Row(2, 2, 2, 2) :: Nil)
|
||||
}
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
|
||||
val e = intercept[Exception] {
|
||||
checkAnswer(
|
||||
testData3.join(testData3),
|
||||
Row(1, null, 1, null) ::
|
||||
Row(1, null, 2, 2) ::
|
||||
Row(2, 2, 1, null) ::
|
||||
Row(2, 2, 2, 2) :: Nil)
|
||||
}
|
||||
assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
|
||||
"disabled by default"))
|
||||
}
|
||||
}
|
||||
|
||||
test("left outer join") {
|
||||
|
|
|
@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
).toDF("a", "b", "c").createOrReplaceTempView("cachedData")
|
||||
|
||||
spark.catalog.cacheTable("cachedData")
|
||||
checkAnswer(
|
||||
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
|
||||
Row(0) :: Row(81) :: Nil)
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
|
||||
Row(0) :: Row(81) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("self join with aliases") {
|
||||
|
@ -435,10 +437,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("left semi greater than predicate") {
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
|
||||
Seq(Row(3, 1), Row(3, 2))
|
||||
)
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
|
||||
Seq(Row(3, 1), Row(3, 2))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
test("left semi greater than predicate and equal operator") {
|
||||
|
@ -824,12 +828,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("cartesian product join") {
|
||||
checkAnswer(
|
||||
testData3.join(testData3),
|
||||
Row(1, null, 1, null) ::
|
||||
Row(1, null, 2, 2) ::
|
||||
Row(2, 2, 1, null) ::
|
||||
Row(2, 2, 2, 2) :: Nil)
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
checkAnswer(
|
||||
testData3.join(testData3),
|
||||
Row(1, null, 1, null) ::
|
||||
Row(1, null, 2, 2) ::
|
||||
Row(2, 2, 1, null) ::
|
||||
Row(2, 2, 2, 2) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("left outer join") {
|
||||
|
|
|
@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
|
|||
}
|
||||
|
||||
test(s"$testName using CartesianProduct") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
|
||||
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
CartesianProductExec(left, right, Some(condition())),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.spark.sql._
|
|||
import org.apache.spark.sql.execution.SparkPlanInfo
|
||||
import org.apache.spark.sql.execution.ui.SparkPlanGraph
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}
|
||||
|
||||
|
@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
|
|||
test("BroadcastNestedLoopJoin metrics") {
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.createOrReplaceTempView("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
|
||||
val df = spark.sql(
|
||||
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
|
||||
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
|
||||
testSparkPlanMetrics(df, 3, Map(
|
||||
1L -> ("BroadcastNestedLoopJoin", Map(
|
||||
"number of output rows" -> 12L)))
|
||||
)
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
|
||||
val df = spark.sql(
|
||||
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
|
||||
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
|
||||
testSparkPlanMetrics(df, 3, Map(
|
||||
1L -> ("BroadcastNestedLoopJoin", Map(
|
||||
"number of output rows" -> 12L)))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
|
|||
}
|
||||
|
||||
test("CartesianProduct metrics") {
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.createOrReplaceTempView("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
|
||||
val df = spark.sql(
|
||||
"SELECT * FROM testData2 JOIN testDataForJoin")
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
0L -> ("CartesianProduct", Map(
|
||||
"number of output rows" -> 12L)))
|
||||
)
|
||||
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.createOrReplaceTempView("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
// ... -> CartesianProduct(nodeId = 0)
|
||||
val df = spark.sql(
|
||||
"SELECT * FROM testData2 JOIN testDataForJoin")
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
|||
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
|
||||
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
|
||||
private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
|
||||
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
|
||||
|
||||
def testCases: Seq[(String, File)] = {
|
||||
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
|
||||
|
@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
|||
// Ensures that the plans generation use metastore relation and not OrcRelation
|
||||
// Was done because SqlBuilder does not work with plans having logical relation
|
||||
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
|
||||
// Ensures that cross joins are enabled so that we can test them
|
||||
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
|
||||
RuleExecutor.resetTime()
|
||||
}
|
||||
|
||||
|
@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
|
|||
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
|
||||
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
|
||||
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
|
||||
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
|
||||
TestHive.sessionState.functionRegistry.restore()
|
||||
|
||||
// For debugging dump some statistics about how much time was spent in various optimizer rules
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
|
|||
import org.apache.spark.sql.hive._
|
||||
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
|
||||
import org.apache.spark.sql.hive.test.TestHive._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
|
||||
case class TestData(a: Int, b: String)
|
||||
|
||||
|
@ -48,6 +49,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
|||
|
||||
import org.apache.spark.sql.hive.test.TestHive.implicits._
|
||||
|
||||
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
|
||||
|
||||
override def beforeAll() {
|
||||
super.beforeAll()
|
||||
TestHive.setCacheTables(true)
|
||||
|
@ -55,6 +58,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
|||
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
|
||||
// Add Locale setting
|
||||
Locale.setDefault(Locale.US)
|
||||
// Ensures that cross joins are enabled so that we can test them
|
||||
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
|
@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|
|||
TimeZone.setDefault(originalTimeZone)
|
||||
Locale.setDefault(originalLocale)
|
||||
sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
|
||||
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue