[SPARK-15425][SQL] Disallow cross joins by default

## What changes were proposed in this pull request?

To prevent users from inadvertently writing queries that contain cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (set to `false` by default). Unless this conf is explicitly set to `true`, any query that contains one or more cartesian products fails with an `AnalysisException`.
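
For illustration only (not part of the patch), here is a minimal sketch of the resulting user-facing behavior. It assumes a local `SparkSession` and two made-up temp views `t1`/`t2`:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo session and data, purely for illustration.
val spark = SparkSession.builder().master("local[2]").appName("cross-join-demo").getOrCreate()
import spark.implicits._
Seq(1, 2).toDF("a").createOrReplaceTempView("t1")
Seq(3, 4).toDF("b").createOrReplaceTempView("t2")

// With spark.sql.crossJoin.enabled left at its default (false), executing this
// cartesian product is expected to fail with an AnalysisException.
// spark.sql("SELECT * FROM t1 JOIN t2").collect()

// Explicitly opting in allows the query to run.
spark.conf.set("spark.sql.crossJoin.enabled", "true")
spark.sql("SELECT * FROM t1 JOIN t2").show()
```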

## How was this patch tested?

Added a test in `JoinSuite` to verify the new behavior. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable cartesian products.
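
For context, the `JoinSuite` change below boils down to roughly this pattern (a simplified sketch, not the verbatim suite code; it relies on the suite's existing `withSQLConf`, `checkAnswer`, and `intercept` helpers and the `testData3` fixture):

```scala
// When the conf is enabled, cartesian products run normally.
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
  checkAnswer(
    testData3.join(testData3),
    Row(1, null, 1, null) :: Row(1, null, 2, 2) ::
    Row(2, 2, 1, null) :: Row(2, 2, 2, 2) :: Nil)
}

// When the conf is disabled, executing the same join is expected to fail.
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
  val e = intercept[Exception] {
    testData3.join(testData3).collect()
  }
  assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive"))
}
```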

Author: Sameer Agarwal <sameer@databricks.com>

Closes #13209 from sameeragarwal/disallow-cartesian.
Authored by Sameer Agarwal on 2016-05-22 23:32:39 -07:00; committed by Reynold Xin
parent fc44b694bf
commit dafcb05c2e
10 changed files with 113 additions and 46 deletions

@@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
planLater(left), planLater(right), buildSide, joinType, condition,
withinBroadcastThreshold = false) :: Nil
// --- Cases where this strategy does not apply ---------------------------------------------

@@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
case class BroadcastNestedLoopJoinExec(
@@ -32,7 +34,8 @@ case class BroadcastNestedLoopJoinExec(
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]) extends BinaryExecNode {
condition: Option[Expression],
withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
)
}
protected override def doPrepare(): Unit = {
if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
throw new AnalysisException("Both sides of this join are outside the broadcasting " +
"threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
}
super.doPrepare()
}
protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()

@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -88,6 +90,15 @@ case class CartesianProductExec(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doPrepare(): Unit = {
if (!sqlContext.conf.crossJoinEnabled) {
throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
"disabled by default. To explicitly enable them, please set " +
s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
}
super.doPrepare()
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

@@ -338,9 +338,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
.doc("When false, we will throw an error if a query contains a cross join")
.booleanConf
.createWithDefault(false)
val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
.doc("When true, the ordinal numbers are treated as the position in the select list. " +
"When false, the ordinal numbers in order/sort By clause are ignored.")
"When false, the ordinal numbers in order/sort by clause are ignored.")
.booleanConf
.createWithDefault(true)
@@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)

@@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("join operator selection") {
spark.cacheManager.clearCache()
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[SortMergeJoinExec]),
@@ -204,13 +205,27 @@ class JoinSuite extends QueryTest with SharedSQLContext {
testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}
test("cartisian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
test("cartesian product join") {
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
val e = intercept[Exception] {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
"disabled by default"))
}
}
test("left outer join") {

@@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
).toDF("a", "b", "c").createOrReplaceTempView("cachedData")
spark.catalog.cacheTable("cachedData")
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
}
}
test("self join with aliases") {
@@ -435,10 +437,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq(Row(3, 1), Row(3, 2))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq(Row(3, 1), Row(3, 2))
)
}
}
test("left semi greater than predicate and equal operator") {
@@ -824,12 +828,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("cartesian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
}
test("left outer join") {

@@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}
test(s"$testName using CartesianProduct") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
CartesianProductExec(left, right, Some(condition())),
expectedAnswer.map(Row.fromTuple),

@@ -29,6 +29,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}
@@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("BroadcastNestedLoopJoin metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
testSparkPlanMetrics(df, 3, Map(
1L -> ("BroadcastNestedLoopJoin", Map(
"number of output rows" -> 12L)))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
testSparkPlanMetrics(df, 3, Map(
1L -> ("BroadcastNestedLoopJoin", Map(
"number of output rows" -> 12L)))
)
}
}
}
@@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("CartesianProduct metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin")
testSparkPlanMetrics(df, 1, Map(
0L -> ("CartesianProduct", Map(
"number of output rows" -> 12L)))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin")
testSparkPlanMetrics(df, 1, Map(
0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
)
}
}
}

@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
def testCases: Seq[(String, File)] = {
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Ensures that the plans generation use metastore relation and not OrcRelation
// Was done because SqlBuilder does not work with plans having logical relation
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
// Ensures that cross joins are enabled so that we can test them
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
RuleExecutor.resetTime()
}
@@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
TestHive.sessionState.functionRegistry.restore()
// For debugging dump some statistics about how much time was spent in various optimizer rules

@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.internal.SQLConf
case class TestData(a: Int, b: String)
@@ -48,6 +49,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
override def beforeAll() {
super.beforeAll()
TestHive.setCacheTables(true)
@@ -55,6 +58,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
// Add Locale setting
Locale.setDefault(Locale.US)
// Ensures that cross joins are enabled so that we can test them
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
}
override def afterAll() {
@@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
} finally {
super.afterAll()
}