[SPARK-9925] [SQL] [TESTS] Set SQLConf.SHUFFLE_PARTITIONS.key correctly for tests
This PR fixes the failed test and the conflict for #8155.

https://issues.apache.org/jira/browse/SPARK-9925

Closes #8155

Author: Yin Huai <yhuai@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #8602 from davies/shuffle_partitions.
parent 22eab706f4
commit 47058ca5db
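
In outline, the change works like this: TestSQLContext and TestHiveContext now keep their test-only defaults (currently just SQLConf.SHUFFLE_PARTITIONS.key -> "5") in a companion overrideConfs map, and re-apply that map from an overridden SQLConf.clear(), so the defaults are present when the session's conf is created and survive explicit clears. A minimal, self-contained sketch of that pattern (the names SimpleConf, TestConf, and TestDefaults are illustrative stand-ins, not Spark internals):

    // Minimal sketch of the pattern this commit introduces (illustrative
    // stand-ins, not the exact Spark internals): test-specific defaults
    // live in one map, and clear() re-applies them so a reset never
    // drops them.
    object TestDefaults {
      // Mirrors TestSQLContext.overrideConfs: fewer shuffle partitions.
      val overrideConfs: Map[String, String] =
        Map("spark.sql.shuffle.partitions" -> "5")
    }

    class SimpleConf {
      private val settings = scala.collection.mutable.Map[String, String]()
      def setConfString(key: String, value: String): Unit = settings(key) = value
      def getAllConfs: Map[String, String] = settings.toMap
      def clear(): Unit = settings.clear()
    }

    class TestConf extends SimpleConf {
      clear() // dispatches to the override below, seeding the defaults up front

      override def clear(): Unit = {
        super.clear()
        // Make sure we start with the default test configs even after clear.
        TestDefaults.overrideConfs.foreach { case (k, v) => setConfString(k, v) }
      }
    }

    object ClearDemo extends App {
      val conf = new TestConf
      conf.setConfString("some.key", "some.value")
      conf.clear()
      // Only the test defaults remain -- exactly what the new assertions in
      // SQLConfSuite check against TestSQLContext.overrideConfs.
      assert(conf.getAllConfs == TestDefaults.overrideConfs)
    }

The real change applies the same idea inside an anonymous SQLConf in each test session, in both sql/core and sql/hive.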
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext}
 
 
 class SQLConfSuite extends QueryTest with SharedSQLContext {
@@ -32,8 +32,12 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
   }
 
   test("programmatic ways of basic setting and getting") {
+    // Set a conf first.
+    sqlContext.setConf(testKey, testVal)
+    // Clear the conf.
     sqlContext.conf.clear()
-    assert(sqlContext.getAllConfs.size === 0)
+    // After clear, only overrideConfs used by unit test should be in the SQLConf.
+    assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs)
 
     sqlContext.setConf(testKey, testVal)
     assert(sqlContext.getConf(testKey) === testVal)
@@ -42,7 +46,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
 
     // Tests SQLConf as accessed from a SQLContext is mutable after
     // the latter is initialized, unlike SparkConf inside a SparkContext.
-    assert(sqlContext.getConf(testKey) == testVal)
+    assert(sqlContext.getConf(testKey) === testVal)
     assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
     assert(sqlContext.getAllConfs.contains(testKey))
 
@@ -73,8 +77,13 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
 
   test("deprecated property") {
     sqlContext.conf.clear()
-    sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
-    assert(sqlContext.conf.numShufflePartitions === 10)
+    val original = sqlContext.conf.numShufflePartitions
+    try{
+      sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+      assert(ctx.conf.numShufflePartitions === 10)
+    } finally {
+      sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original")
+    }
   }
 
   test("invalid conf value") {
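
The reworked "deprecated property" test above also switches to a save-and-restore idiom: it captures the current numShufflePartitions, mutates it via SET, and restores it in finally so the shared conf is not polluted for later tests. A generic sketch of that idiom (the helper below is a stand-in, not part of Spark's test utilities):

    // Generic form of the save-and-restore idiom the updated test uses:
    // capture the current value, mutate it for the body, restore in finally
    // so a failure cannot leak the temporary setting into later tests.
    object RestoreDemo extends App {
      def withTempSetting[T](get: => T, set: T => Unit, temp: T)(body: => Unit): Unit = {
        val original = get
        set(temp)
        try body
        finally set(original)
      }

      var shufflePartitions = 5
      withTempSetting(shufflePartitions, (v: Int) => shufflePartitions = v, 10) {
        assert(shufflePartitions == 10)
      }
      assert(shufflePartitions == 5) // restored even if the body had thrown
    }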

@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
 import org.apache.spark.sql.types._
 
 /** A SQL Dialect for testing purpose, and it can not be nested type */
@@ -991,21 +991,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     val nonexistentKey = "nonexistent"
 
     // "set" itself returns all config variables currently specified in SQLConf.
-    assert(sql("SET").collect().size == 0)
+    assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size)
+    sql("SET").collect().foreach { row =>
+      val key = row.getString(0)
+      val value = row.getString(1)
+      assert(
+        TestSQLContext.overrideConfs.contains(key),
+        s"$key should exist in SQLConf.")
+      assert(
+        TestSQLContext.overrideConfs(key) === value,
+        s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.")
+    }
+    val overrideConfs = sql("SET").collect()
 
     // "set key=val"
     sql(s"SET $testKey=$testVal")
     checkAnswer(
       sql("SET"),
-      Row(testKey, testVal)
+      overrideConfs ++ Seq(Row(testKey, testVal))
     )
 
     sql(s"SET ${testKey + testKey}=${testVal + testVal}")
     checkAnswer(
       sql("set"),
-      Seq(
-        Row(testKey, testVal),
-        Row(testKey + testKey, testVal + testVal))
+      overrideConfs ++ Seq(Row(testKey, testVal), Row(testKey + testKey, testVal + testVal))
     )
 
     // "set key"

@@ -31,13 +31,24 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
       new SparkConf().set("spark.sql.testkey", "true")))
   }
 
-  // Use fewer partitions to speed up testing
+  // Make sure we set those test specific confs correctly when we create
+  // the SQLConf as well as when we call clear.
   protected[sql] override def createSession(): SQLSession = new this.SQLSession()
 
   /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
   protected[sql] class SQLSession extends super.SQLSession {
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
-      override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
+
+      clear()
+
+      override def clear(): Unit = {
+        super.clear()
+
+        // Make sure we start with the default test configs even after clear
+        TestSQLContext.overrideConfs.map {
+          case (key, value) => setConfString(key, value)
+        }
+      }
     }
   }
 
@@ -50,3 +61,14 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
     protected override def sqlContext: SQLContext = self
   }
 }
+
+private[sql] object TestSQLContext {
+
+  /**
+   * A map used to store all confs that need to be overridden in sql/core unit tests.
+   */
+  val overrideConfs: Map[String, String] =
+    Map(
+      // Fewer shuffle partitions to speed up testing.
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5")
+}
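
The new TestSQLContext.overrideConfs object above is the heart of the fix: by writing the override under SQLConf.SHUFFLE_PARTITIONS.key rather than overriding the numShufflePartitions getter, the value lands in the settings map itself, where SET and getAllConfs can see it. A small sketch of the difference (KeyValueConf is a stand-in, not Spark's SQLConf):

    // Why the fix stores the override under the conf *key* instead of
    // overriding the getter: only entries actually present in the settings
    // map are visible to code that enumerates it, such as the SQL "SET"
    // command or getAllConfs.
    class KeyValueConf {
      protected val settings = scala.collection.mutable.Map[String, String]()
      def setConf(key: String, value: String): Unit = settings(key) = value
      def getAllConfs: Map[String, String] = settings.toMap
      def numShufflePartitions: Int =
        settings.getOrElse("spark.sql.shuffle.partitions", "200").toInt
    }

    object KeyVsOverride extends App {
      // Old approach (sketched): override the getter; the map stays empty,
      // so "SET" would list nothing even though the value is in effect.
      val overridden = new KeyValueConf {
        override def numShufflePartitions: Int = 5
      }
      assert(overridden.numShufflePartitions == 5)
      assert(overridden.getAllConfs.isEmpty)

      // New approach: set the key; the getter and enumeration now agree.
      val keyed = new KeyValueConf
      keyed.setConf("spark.sql.shuffle.partitions", "5")
      assert(keyed.numShufflePartitions == 5)
      assert(keyed.getAllConfs == Map("spark.sql.shuffle.partitions" -> "5"))
    }

This is also why the assertions in SQLConfSuite, SQLQuerySuite, and HiveQuerySuite change from expecting an empty SET result to expecting exactly the override map.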

@@ -116,18 +116,28 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
   override def executePlan(plan: LogicalPlan): this.QueryExecution =
     new this.QueryExecution(plan)
 
+  // Make sure we set those test specific confs correctly when we create
+  // the SQLConf as well as when we call clear.
   override protected[sql] def createSession(): SQLSession = {
     new this.SQLSession()
   }
 
   protected[hive] class SQLSession extends super.SQLSession {
-    /** Fewer partitions to speed up testing. */
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
-      override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
       // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared.
       // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql"
       override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql")
       override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
+
+      clear()
+
+      override def clear(): Unit = {
+        super.clear()
+
+        TestHiveContext.overrideConfs.map {
+          case (key, value) => setConfString(key, value)
+        }
+      }
     }
   }
 
@@ -455,3 +465,15 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
     }
   }
 }
+
+private[hive] object TestHiveContext {
+
+  /**
+   * A map used to store all confs that need to be overridden in sql/hive unit tests.
+   */
+  val overrideConfs: Map[String, String] =
+    Map(
+      // Fewer shuffle partitions to speed up testing.
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5"
+    )
+}

@@ -173,6 +173,7 @@ object SparkSubmitClassLoaderTest extends Logging {
   def main(args: Array[String]) {
     Utils.configTestLog4j("INFO")
     val conf = new SparkConf()
+    conf.set("spark.ui.enabled", "false")
     val sc = new SparkContext(conf)
     val hiveContext = new TestHiveContext(sc)
     val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j")
@@ -264,6 +265,7 @@ object SparkSQLConfTest extends Logging {
       // For this simple test, we do not really clone this object.
       override def clone: SparkConf = this
     }
+    conf.set("spark.ui.enabled", "false")
    val sc = new SparkContext(conf)
     val hiveContext = new TestHiveContext(sc)
     // Run a simple command to make sure all lazy vals in hiveContext get instantiated.
@@ -283,7 +285,8 @@ object SPARK_9757 extends QueryTest {
     val sparkContext = new SparkContext(
       new SparkConf()
         .set("spark.sql.hive.metastore.version", "0.13.1")
-        .set("spark.sql.hive.metastore.jars", "maven"))
+        .set("spark.sql.hive.metastore.jars", "maven")
+        .set("spark.ui.enabled", "false"))
 
     val hiveContext = new TestHiveContext(sparkContext)
     sqlContext = hiveContext

@@ -25,8 +25,10 @@ class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll {
   ignore("multiple instances not supported") {
     test("Multiple Hive Instances") {
       (1 to 10).map { i =>
+        val conf = new SparkConf()
+        conf.set("spark.ui.enabled", "false")
         val ts =
-          new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf()))
+          new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf))
         ts.executeSql("SHOW TABLES").toRdd.collect()
         ts.executeSql("SELECT * FROM src").toRdd.collect()
         ts.executeSql("SHOW TABLES").toRdd.collect()
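
Several hunks in the Hive test suites also set spark.ui.enabled to false when constructing a SparkContext. Disabling the UI keeps each short-lived test context from starting a web server (and competing for its port). A minimal sketch of such a context, assuming a hypothetical standalone harness rather than one of the suites touched here:

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch of the test setup these hunks converge on: a local context
    // with the web UI disabled, so suites that create many contexts in
    // sequence do not each start a UI server.
    object QuietLocalContext {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local")
          .setAppName("quiet-test-context")
          .set("spark.ui.enabled", "false")
        val sc = new SparkContext(conf)
        try {
          assert(sc.parallelize(1 to 10).count() == 10)
        } finally {
          sc.stop()
        }
      }
    }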

@@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.hive._
+import org.apache.spark.sql.hive.test.TestHiveContext
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 
@@ -1104,18 +1105,19 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
 
     // "SET" itself returns all config variables currently specified in SQLConf.
     // TODO: Should we be listing the default here always? probably...
-    assert(sql("SET").collect().size == 0)
+    assert(sql("SET").collect().size === TestHiveContext.overrideConfs.size)
 
+    val defaults = collectResults(sql("SET"))
     assertResult(Set(testKey -> testVal)) {
       collectResults(sql(s"SET $testKey=$testVal"))
     }
 
-    assert(hiveconf.get(testKey, "") == testVal)
-    assertResult(Set(testKey -> testVal))(collectResults(sql("SET")))
+    assert(hiveconf.get(testKey, "") === testVal)
+    assertResult(defaults ++ Set(testKey -> testVal))(collectResults(sql("SET")))
 
     sql(s"SET ${testKey + testKey}=${testVal + testVal}")
     assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
-    assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
+    assertResult(defaults ++ Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) {
       collectResults(sql("SET"))
     }
 