[SPARK-9580] [SQL] Replace singletons in SQL tests

A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure.
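For illustration, the kind of setup a suite like `BroadcastJoinSuite` needs but cannot get from a shared singleton looks roughly like the sketch below. The master string, app name, and teardown are assumptions for illustration only, not taken from the actual suite:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical: without a global singleton, a suite can spin up a context
// with whatever master/config it needs and tear it down when it is done.
val conf = new SparkConf()
  .setMaster("local-cluster[2,1,1024]") // multiple executors, not just local[*]
  .setAppName("BroadcastJoinSuite")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// ... run broadcast-join assertions against sqlContext ...
sc.stop()
```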

This patch removes the singletons `TestSQLContext` and `TestData` and instead introduces a `SharedSQLContext` that starts a fresh context per suite. Unfortunately, the singletons were so deeply ingrained in the SQL tests that the patch had to touch *all* of the SQL test files.
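For reference, a suite written against the new infrastructure follows roughly the pattern below (a minimal sketch: the suite name and test data are made up, while `SharedSQLContext`, `testImplicits`, `sql`, and `checkAnswer` mirror the usage visible in the diff):

```scala
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._  // per-suite implicits, replacing TestSQLContext.implicits._

  test("filter using the per-suite shared context") {
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.registerTempTable("kv")
    checkAnswer(sql("SELECT key FROM kv WHERE value = 'a'"), Row(1))
  }
}
```

Because each suite owns its context, a suite that needs a different master or config can simply provide its own instead of fighting the singleton.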


Author: Andrew Or <andrew@databricks.com>

Closes #8111 from andrewor14/sql-tests-refactor.
Commit 8187b3ae47 (parent c50f97dafd), authored by Andrew Or on 2015-08-13 17:42:01 -07:00 and committed by Reynold Xin.
96 changed files with 1461 additions and 1204 deletions


@ -178,6 +178,16 @@ object MimaExcludes {
// SPARK-4751 Dynamic allocation for standalone mode
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.SparkContext.supportDynamicAllocation")
) ++ Seq(
// SPARK-9580: Remove SQL test singletons
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.LocalSQLContext"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.TestSQLContext"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.TestSQLContext$")
) ++ Seq(
// SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
ProblemFilters.exclude[IncompatibleResultTypeProblem](


@ -319,6 +319,8 @@ object SQL {
lazy val settings = Seq(
initialCommands in console :=
"""
|import org.apache.spark.SparkContext
|import org.apache.spark.sql.SQLContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._
@ -328,9 +330,14 @@ object SQL {
|import org.apache.spark.sql.catalyst.util._
|import org.apache.spark.sql.execution
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.test.TestSQLContext._
|import org.apache.spark.sql.types._""".stripMargin,
cleanupCommands in console := "sparkContext.stop()"
|import org.apache.spark.sql.types._
|
|val sc = new SparkContext("local[*]", "dev-shell")
|val sqlContext = new SQLContext(sc)
|import sqlContext.implicits._
|import sqlContext._
""".stripMargin,
cleanupCommands in console := "sc.stop()"
)
}
@ -340,8 +347,6 @@ object Hive {
javaOptions += "-XX:MaxPermSize=256m",
// Specially disable assertions since some Hive tests fail them
javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
// Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
// only for this subproject.
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
@ -349,6 +354,7 @@ object Hive {
},
initialCommands in console :=
"""
|import org.apache.spark.SparkContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._


@ -17,14 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types._
@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode {
override def output: Seq[Attribute] = Nil
}
class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
class AnalysisErrorSuite extends AnalysisTest {
import TestRelations._
def errorTest(


@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConversions._
import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
@ -334,98 +332,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
@Experimental
object implicits extends Serializable {
object implicits extends SQLImplicits with Serializable {
protected override def _sqlContext: SQLContext = self
}
// scalastyle:on
/**
* Converts $"col name" into an [[Column]].
* @since 1.3.0
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
new ColumnName(sc.s(args: _*))
}
}
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
/**
* Creates a DataFrame from an RDD of case classes or tuples.
* @since 1.3.0
*/
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
DataFrameHolder(self.createDataFrame(rdd))
}
/**
* Creates a DataFrame from a local Seq of Product.
* @since 1.3.0
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
{
DataFrameHolder(self.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
// because of [[DoubleRDDFunctions]].
/**
* Creates a single column DataFrame from an RDD[Int].
* @since 1.3.0
*/
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setInt(0, v)
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[Long].
* @since 1.3.0
*/
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setLong(0, v)
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[String].
* @since 1.3.0
*/
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.update(0, UTF8String.fromString(v))
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
/**
* :: Experimental ::
* Creates a DataFrame from an RDD of case classes.


@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.StructField
import org.apache.spark.unsafe.types.UTF8String
/**
* A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
*/
private[sql] abstract class SQLImplicits {
protected def _sqlContext: SQLContext
/**
* Converts $"col name" into an [[Column]].
* @since 1.3.0
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
new ColumnName(sc.s(args: _*))
}
}
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
/**
* Creates a DataFrame from an RDD of case classes or tuples.
* @since 1.3.0
*/
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
DataFrameHolder(_sqlContext.createDataFrame(rdd))
}
/**
* Creates a DataFrame from a local Seq of Product.
* @since 1.3.0
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
{
DataFrameHolder(_sqlContext.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
// because of [[DoubleRDDFunctions]].
/**
* Creates a single column DataFrame from an RDD[Int].
* @since 1.3.0
*/
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setInt(0, v)
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[Long].
* @since 1.3.0
*/
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setLong(0, v)
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[String].
* @since 1.3.0
*/
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.update(0, UTF8String.fromString(v))
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}


@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable {
@Before
public void setUp() {
sqlContext = TestSQLContext$.MODULE$;
javaCtx = new JavaSparkContext(sqlContext.sparkContext());
SparkContext context = new SparkContext("local[*]", "testing");
javaCtx = new JavaSparkContext(context);
sqlContext = new SQLContext(context);
}
@After
public void tearDown() {
javaCtx = null;
sqlContext.sparkContext().stop();
sqlContext = null;
javaCtx = null;
}
public static class Person implements Serializable {


@ -17,44 +17,45 @@
package test.org.apache.spark.sql;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
import org.junit.*;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.junit.*;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.types.*;
public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
private transient SQLContext context;
private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
jsc = new JavaSparkContext(TestSQLContext.sparkContext());
context = TestSQLContext$.MODULE$;
SparkContext sc = new SparkContext("local[*]", "testing");
jsc = new JavaSparkContext(sc);
context = new TestSQLContext(sc);
context.loadTestData();
}
@After
public void tearDown() {
jsc = null;
context.sparkContext().stop();
context = null;
jsc = null;
}
@Test
@ -230,7 +231,7 @@ public class JavaDataFrameSuite {
@Test
public void testSampleBy() {
DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};


@ -23,12 +23,12 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable {
@Before
public void setUp() {
sqlContext = TestSQLContext$.MODULE$;
sc = new JavaSparkContext(sqlContext.sparkContext());
SparkContext _sc = new SparkContext("local[*]", "testing");
sqlContext = new SQLContext(_sc);
sc = new JavaSparkContext(_sc);
}
@After
public void tearDown() {
sqlContext.sparkContext().stop();
sqlContext = null;
sc = null;
}
@SuppressWarnings("unchecked")


@ -21,13 +21,14 @@ import java.io.File;
import java.io.IOException;
import java.util.*;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
@ -52,8 +53,9 @@ public class JavaSaveLoadSuite {
@Before
public void setUp() throws IOException {
sqlContext = TestSQLContext$.MODULE$;
sc = new JavaSparkContext(sqlContext.sparkContext());
SparkContext _sc = new SparkContext("local[*]", "testing");
sqlContext = new SQLContext(_sc);
sc = new JavaSparkContext(_sc);
originalDefaultSource = sqlContext.conf().defaultDataSourceName();
path =
@ -71,6 +73,13 @@ public class JavaSaveLoadSuite {
df.registerTempTable("jsonTable");
}
@After
public void tearDown() {
sqlContext.sparkContext().stop();
sqlContext = null;
sc = null;
}
@Test
public void saveAndLoad() {
Map<String, String> options = new HashMap<String, String>();


@ -18,24 +18,20 @@
package org.apache.spark.sql
import scala.concurrent.duration._
import scala.language.{implicitConversions, postfixOps}
import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
private case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
class CachedTableSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def rddIdOf(tableName: String): Int = {
val executedPlan = ctx.table(tableName).queryExecution.executedPlan


@ -21,16 +21,20 @@ import org.scalatest.Matchers._
import org.apache.spark.sql.execution.{Project, TungstenProject}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.SQLTestUtils
class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
override def sqlContext(): SQLContext = ctx
private lazy val booleanData = {
ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
Row(true, true) :: Nil),
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
}
test("column names with space") {
val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
checkAnswer(
ctx.sql("select isnull(null), isnull(1)"),
sql("select isnull(null), isnull(1)"),
Row(true, false))
}
@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
checkAnswer(
ctx.sql("select isnotnull(null), isnotnull('a')"),
sql("select isnotnull(null), isnotnull('a')"),
Row(false, true))
}
@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
checkAnswer(
ctx.sql("select isnan(15), isnan('invalid')"),
sql("select isnan(15), isnan('invalid')"),
Row(false, false))
}
@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
testData.registerTempTable("t")
checkAnswer(
ctx.sql(
sql(
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
" nanvl(b, e), nanvl(e, f) from t"),
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
}
}
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
Row(true, true) :: Nil),
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
test("&&") {
checkAnswer(
booleanData.filter($"a" && true),
@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
ctx.sql("SELECT upper('aB'), ucase('cDe')"),
sql("SELECT upper('aB'), ucase('cDe')"),
Row("AB", "CDE"))
}
@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
ctx.sql("SELECT lower('aB'), lcase('cDe')"),
sql("SELECT lower('aB'), lcase('cDe')"),
Row("ab", "cde"))
}


@ -17,15 +17,13 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{BinaryType, DecimalType}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DecimalType
class DataFrameAggregateSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("groupBy") {
checkAnswer(


@ -17,17 +17,15 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
class DataFrameFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("array with column name") {
val df = Seq((0, 1)).toDF("a", "b")
@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest {
test("constant functions") {
checkAnswer(
ctx.sql("SELECT E()"),
sql("SELECT E()"),
Row(scala.math.E)
)
checkAnswer(
ctx.sql("SELECT PI()"),
sql("SELECT PI()"),
Row(scala.math.Pi)
)
}
@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("nvl function") {
checkAnswer(
ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
Row("x", "y", null))
}
@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(-1)
)
checkAnswer(
ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
sql("SELECT least(a, 2) as l from testData2 order by l"),
Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
)
}
@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(3)
)
checkAnswer(
ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
sql("SELECT greatest(a, 2) as g from testData2 order by g"),
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}


@ -17,10 +17,10 @@
package org.apache.spark.sql
class DataFrameImplicitsSuite extends QueryTest {
import org.apache.spark.sql.test.SharedSQLContext
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("RDD of tuples") {
checkAnswer(


@ -17,14 +17,12 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameJoinSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("join - join using") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest {
checkAnswer(
df1.join(df2, $"df1.key" === $"df2.key"),
ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
.collect().toSeq)
}


@ -19,11 +19,11 @@ package org.apache.spark.sql
import scala.collection.JavaConversions._
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameNaFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def createDF(): DataFrame = {
Seq[(String, java.lang.Integer, java.lang.Double)](


@ -19,20 +19,17 @@ package org.apache.spark.sql
import java.util.Random
import org.scalatest.Matchers._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameStatSuite extends QueryTest {
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private def toLetter(i: Int): String = (i + 97).toChar.toString
test("sample with replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
Seq(5, 10, 52, 73).map(Row(_))
@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest {
test("sample without replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
Seq(16, 23, 88, 100).map(Row(_))
@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest {
test("randomSplit") {
val n = 600
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("Frequent Items 2") {
val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
// this is a regression test, where when merging partitions, we omitted values with higher
// counts than those that existed in the map when the map was full. This test should also fail
// if anything like SPARK-9614 is observed once again
@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("sampleBy") {
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),


@ -23,18 +23,12 @@ import scala.language.postfixOps
import scala.util.Random
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
class DataFrameSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class DataFrameSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("analysis error should be eagerly reported") {
// Eager analysis.


@ -18,7 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
@ -27,10 +27,8 @@ import org.apache.spark.sql.types._
* This is here for now so I can make sure Tungsten project is tested without refactoring existing
* end-to-end test infra. In the long run this should just go away.
*/
class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("test simple types") {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {


@ -22,19 +22,18 @@ import java.text.SimpleDateFormat
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.CalendarInterval
class DateFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DateFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("function current_date") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
val d2 = DateTimeUtils.fromJavaDate(
ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
}
@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
// Execution in one query should return the same value
checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
Row(true))
assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
0).getTime - System.currentTimeMillis()) < 5000)
}


@ -17,22 +17,15 @@
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.logicalPlanToSparkQuery
setupTestData()
test("equi-join is hash-join") {
val x = testData2.as("x")
@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
def assertJoin(sqlString: String, c: Class[_]): Any = {
val df = ctx.sql(sqlString)
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted hash join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
for (sortMergeJoinEnabled <- Seq(true, false)) {
withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("broadcasted hash outer join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
classOf[BroadcastHashOuterJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("multiple-key equi-join is hash-join") {
@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(6, 1) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 6))
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 10))
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted left semi join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("left semi join") {
val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::


@ -17,10 +17,10 @@
package org.apache.spark.sql
class JsonFunctionsSuite extends QueryTest {
import org.apache.spark.sql.test.SharedSQLContext
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("function get_json_object") {
val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")


@ -19,12 +19,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
class ListTablesSuite extends QueryTest with BeforeAndAfter {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
import testImplicits._
private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
tableDF.registerTempTable("tables")
checkAnswer(
ctx.sql(
sql(
"SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)


@ -19,18 +19,16 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
import org.apache.spark.sql.test.SharedSQLContext
private object MathExpressionsTestData {
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
case class NullDoubles(a: java.lang.Double)
}
class MathExpressionsSuite extends QueryTest {
class MathExpressionsSuite extends QueryTest with SharedSQLContext {
import MathExpressionsTestData._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import testImplicits._
private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest {
test("toDegrees") {
testOneToOneMathFunction(toDegrees, math.toDegrees)
checkAnswer(
ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
)
}
@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest {
test("toRadians") {
testOneToOneMathFunction(toRadians, math.toRadians)
checkAnswer(
ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
sql("SELECT radians(0), radians(1), radians(1.5)"),
Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
)
}
@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest {
test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
checkAnswer(
ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0.0, 1.0, 2.0))
}
@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest {
val pi = 3.1415
checkAnswer(
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction[Double](signum, math.signum)
checkAnswer(
ctx.sql("SELECT sign(10), signum(-11)"),
sql("SELECT sign(10), signum(-11)"),
Row(1, -1))
}
@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest {
testTwoToOneMathFunction(pow, pow, math.pow)
checkAnswer(
ctx.sql("SELECT pow(1, 2), power(2, 1)"),
sql("SELECT pow(1, 2), power(2, 1)"),
Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
)
}
@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest {
test("log / ln") {
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
checkAnswer(
ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
sql("SELECT ln(0), ln(1), ln(1.5)"),
Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
)
}
@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest {
df.select(log2("b") + log2("a")),
Row(1))
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
}
test("sqrt") {
@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest {
df.select(sqrt("a"), sqrt("b")),
Row(1.0, 2.0))
checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
}
test("negative") {
checkAnswer(
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
sql("SELECT negative(1), negative(0), negative(-1)"),
Row(-1, 0, 1))
}


@ -71,12 +71,6 @@ class QueryTest extends PlanTest {
checkAnswer(df, expectedAnswer.collect())
}
def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
}
/**
* Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
*/


@ -20,13 +20,12 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class RowSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class RowSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("create row") {
val expected = new GenericMutableRow(4)


@ -17,11 +17,10 @@
package org.apache.spark.sql
import org.apache.spark.sql.test.SharedSQLContext
class SQLConfSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SQLConfSuite extends QueryTest with SharedSQLContext {
private val testKey = "test.key.0"
private val testVal = "test.val.0"
@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest {
test("parse SQL set commands") {
ctx.conf.clear()
ctx.sql(s"set $testKey=$testVal")
sql(s"set $testKey=$testVal")
assert(ctx.getConf(testKey, testVal + "_") === testVal)
assert(ctx.getConf(testKey, testVal + "_") === testVal)
ctx.sql("set some.property=20")
sql("set some.property=20")
assert(ctx.getConf("some.property", "0") === "20")
ctx.sql("set some.property = 40")
sql("set some.property = 40")
assert(ctx.getConf("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
ctx.sql(s"set $key=$vs")
sql(s"set $key=$vs")
assert(ctx.getConf(key, "0") === vs)
ctx.sql(s"set $key=")
sql(s"set $key=")
assert(ctx.getConf(key, "0") === "")
ctx.conf.clear()
@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest {
test("deprecated property") {
ctx.conf.clear()
ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
assert(ctx.conf.numShufflePartitions === 10)
}
test("invalid conf value") {
ctx.conf.clear()
val e = intercept[IllegalArgumentException] {
ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
}


@ -17,16 +17,17 @@
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
override def afterAll(): Unit = {
try {
SQLContext.setLastInstantiatedContext(ctx)
} finally {
super.afterAll()
}
}
test("getOrCreate instantiates SQLContext") {


@ -19,28 +19,23 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
class SQLQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
setupTestData()
test("having clause") {
Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("show functions") {
checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
checkAnswer(sql("SHOW functions"),
FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
}
test("describe functions") {
@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
// we except the id is materialized once
val idUDF = udf(() => UUID.randomUUID().toString)
val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString)
val dfWithId = df.withColumn("id", idUDF())
// Make a new DataFrame (actually the same reference to the old one)
@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(
sql(
"""
|SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
""".stripMargin),
"SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
Row(2, 1, 2, 2, 1))
}
@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
validateMetadata(sql(
"SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
}
test("SPARK-3371 Renaming a function expression with group by gives error") {
@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
.toDF("num", "str")
df.registerTempTable("1one")
checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10))
checkAnswer(sql("select count(num) from 1one"), Row(10))
sqlContext.dropTempTable("1one")
}


@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
case class ReflectData(
stringField: String,
@ -71,17 +72,15 @@ case class ComplexReflectData(
mapFieldContainsNull: Map[Int, Option[Long]],
dataField: Data)
class ScalaReflectionRelationSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
Seq(data).toDF().registerTempTable("reflectData")
assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = NullReflectData(null, null, null, null, null, null, null)
Seq(data).toDF().registerTempTable("reflectNullData")
assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
assert(sql("SELECT * FROM reflectNullData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = OptionalReflectData(None, None, None, None, None, None, None)
Seq(data).toDF().registerTempTable("reflectOptionalData")
assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
test("query binary data") {
Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
val result = ctx.sql("SELECT data FROM reflectBinary")
val result = sql("SELECT data FROM reflectBinary")
.collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
}
@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
Nested(None, "abc")))
Seq(data).toDF().registerTempTable("reflectComplexData")
assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
Row(
Seq(1, 2, 3),
Seq(1, 2, null),


@ -19,13 +19,12 @@ package org.apache.spark.sql
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.test.SharedSQLContext
class SerializationSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
val sqlContext = new SQLContext(ctx.sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
val _sqlContext = new SQLContext(sqlContext.sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
}
}


@ -18,13 +18,12 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.Decimal
class StringFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class StringFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("string concat") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")


@ -1,197 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
val largeAndSmallInts =
TestSQLContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
LargeAndSmallInts(3, 2) :: Nil).toDF()
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
val testData2 =
TestSQLContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil, 2).toDF()
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
val decimalData =
TestSQLContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
DecimalData(3, 2) :: Nil).toDF()
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
val binaryData =
TestSQLContext.sparkContext.parallelize(
BinaryData("12".getBytes(), 1) ::
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
BinaryData("123".getBytes(), 4) :: Nil).toDF()
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
testData3.registerTempTable("testData3")
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
TestSQLContext.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
UpperCaseData(6, "F") :: Nil).toDF()
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
val lowerCaseData =
TestSQLContext.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil).toDF()
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
val arrayData =
TestSQLContext.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
arrayData.toDF().registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])
val mapData =
TestSQLContext.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
mapData.toDF().registerTempTable("mapData")
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
repeatedData.toDF().registerTempTable("repeatedData")
val nullableRepeatedData =
TestSQLContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
case class NullInts(a: Integer)
val nullInts =
TestSQLContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
).toDF()
nullInts.registerTempTable("nullInts")
val allNulls =
TestSQLContext.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
NullInts(null) :: Nil).toDF()
allNulls.registerTempTable("allNulls")
case class NullStrings(n: Int, s: String)
val nullStrings =
TestSQLContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
nullStrings.registerTempTable("nullStrings")
case class TableName(tableName: String)
TestSQLContext
.sparkContext
.parallelize(TableName("test") :: Nil)
.toDF()
.registerTempTable("tableName")
val unparsedStrings =
TestSQLContext.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)
case class IntField(i: Int)
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
withEmptyParts.toDF().registerTempTable("withEmptyParts")
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
val person = TestSQLContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
person.registerTempTable("person")
val salary = TestSQLContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")
}
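The block above parallelizes small case-class datasets into DataFrames and registers each one as a temp table so the suites can refer to them by name in SQL. A minimal, self-contained sketch of that registration pattern (the class and table names here are illustrative, not part of the patch):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Illustrative case class; the real test data uses many such classes.
case class PersonRecord(id: Int, name: String)

object TestDataSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "test-data-sketch")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Parallelize local case-class instances, convert to a DataFrame,
    // and register it under a name that SQL queries can reference.
    val people = sqlContext.sparkContext.parallelize(
      PersonRecord(1, "alice") ::
      PersonRecord(2, "bob") :: Nil).toDF()
    people.registerTempTable("people")

    sqlContext.sql("SELECT name FROM people WHERE id = 1").show()
    sc.stop()
  }
}
```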

View file

@ -17,16 +17,13 @@
package org.apache.spark.sql
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
case class FunctionResult(f1: String, f2: String)
private case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest with SQLTestUtils {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
override def sqlContext(): SQLContext = ctx
class UDFSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("built-in fixed arity expressions") {
val df = ctx.emptyDataFrame
@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("SPARK-8003 spark_partition_id") {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
ctx.dropTempTable("tmp_table")
}
@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
val answer = sql("select input_file_name() from test_table").head().getString(0)
assert(answer.contains(dir.getCanonicalPath))
assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
ctx.dropTempTable("test_table")
}
}
@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("Simple UDF") {
ctx.udf.register("strLenScala", (_: String).length)
assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
ctx.udf.register("random0", () => { Math.random()})
assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
ctx.udf.register("strLenScala", (_: String).length + (_: Int))
assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("UDF in a WHERE") {
@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("integerData")
val result =
ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
assert(result.count() === 20)
}
@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT g, SUM(v) as s
| FROM groupData
@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT SUM(v)
| FROM groupData
@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT timesHundred(SUM(v)) as v100
| FROM groupData
@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
ctx.sql("SELECT returnStruct('test', 'test2') as ret")
sql("SELECT returnStruct('test', 'test2') as ret")
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("udf that is transformed") {
ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant-folded, causing a transformation.
assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
test("type coercion for udf inputs") {
ctx.udf.register("intExpected", (x: Int) => x)
// pass a decimal to intExpected.
assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
}
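The UDFSuite rewrite above is representative of the whole patch: drop the `TestSQLContext` singleton, mix in `SharedSQLContext`, and use the `ctx`, `sql(...)` and `testImplicits` members it provides (names as used by the rewritten suites). A sketch of the resulting suite skeleton, with a made-up UDF:

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Illustrative only: the suite name and UDF are invented, but the shape
// mirrors the rewritten suites in this patch.
class ExampleUDFSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("register and call a Scala UDF") {
    ctx.udf.register("plusOne", (x: Int) => x + 1)
    Seq((1, "a"), (2, "b")).toDF("n", "s").registerTempTable("numbers")
    checkAnswer(sql("SELECT plusOne(n) FROM numbers"), Seq(Row(2), Row(3)))
  }
}
```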

View file

@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashSet
@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
private[spark] override def asNullable: MyDenseVectorUDT = this
}
class UserDefinedTypeSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val pointsRDD = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
ctx.sql("SELECT testType(features) from points"),
sql("SELECT testType(features) from points"),
Seq(Row(true), Row(true)))
}

View file

@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.{logicalPlanToSparkQuery, sql}
setupTestData()
test("simple columnar query") {
val plan = ctx.executePlan(testData.logicalPlan).executedPlan

View file

@ -17,20 +17,19 @@
package org.apache.spark.sql.columnar
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
}
before {
ctx.cacheTable("pruningData")
}
after {
override protected def afterAll(): Unit = {
try {
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
ctx.uncacheTable("pruningData")
} finally {
super.afterAll()
}
}
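The hooks above follow a pattern used throughout the patch: `beforeAll()` defers to `super.beforeAll()` so the shared context exists before any suite-local configuration, and `afterAll()` restores that configuration inside `try` so that `super.afterAll()`, which tears the shared context down, always runs in `finally`. A compact sketch of the idiom with an invented suite and config value:

```scala
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative suite: overrides one SQL conf for the duration of the suite
// and restores it afterwards without skipping the shared-context teardown.
class ExampleConfSuite extends SparkFunSuite with SharedSQLContext {
  private var originalColumnBatchSize: Int = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()                                   // starts the shared context
    originalColumnBatchSize = ctx.conf.columnBatchSize  // remember the original value
    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
  }

  override protected def afterAll(): Unit = {
    try {
      ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
    } finally {
      super.afterAll()                                  // always stops the shared context
    }
  }

  test("the conf is overridden for this suite only") {
    assert(ctx.conf.columnBatchSize === 10)
  }
}
```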
// Comparisons
@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
val df = ctx.sql(query)
val df = sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {

View file

@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest {
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("shuffling UnsafeRows in exchange") {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans._
@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
class PlannerSuite extends SparkFunSuite with SQLTestUtils {
class PlannerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
override def sqlContext: SQLContext = TestSQLContext
setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val _ctx = ctx
import _ctx.planner._
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
plannedOption.getOrElse(
@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
}
test("unions are collapsed") {
val _ctx = ctx
import _ctx.planner._
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
val logicalUnions = query collect { case u: logical.Union => u }
@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
val fields = fieldTypes.zipWithIndex.map {
case (dataType, index) => StructField(s"c${index}", dataType, true)
} :+ StructField("key", IntegerType, true)
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
dropTempTable("testLimit")
ctx.dropTempTable("testLimit")
}
val origThreshold = conf.autoBroadcastJoinThreshold
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
val simpleTypes =
NullType ::
@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
checkPlan(complexTypes, newThreshold = 901617)
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("InMemoryRelation statistics propagation") {
val origThreshold = conf.autoBroadcastJoinThreshold
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
val a = testData.as("a")
val b = table("tiny").as("b")
val b = ctx.table("tiny").as("b")
val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("efficient limit -> project -> sort") {
val query = testData.sort('key).select('value).limit(2).logicalPlan
val planned = planner.TakeOrderedAndProject(query)
val planned = ctx.planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}

View file

@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
import org.apache.spark.unsafe.types.UTF8String
class RowFormatConvertersSuite extends SparkPlanTest {
class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
case c: ConvertToUnsafe => c
@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
}
test("filter can process unsafe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
}
@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("union requires all of its input rows' formats to agree") {
val plan = Union(Seq(outputsSafe, outputsUnsafe))
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("union can process safe rows") {
val plan = Union(Seq(outputsSafe, outputsSafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(!preparedPlan.outputsUnsafeRows)
}
test("union can process unsafe rows") {
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("round trip with ConvertToUnsafe and ConvertToSafe") {
val input = Seq(("hello", 1), ("world", 2))
checkAnswer(
TestSQLContext.createDataFrame(input),
ctx.createDataFrame(input),
plan => ConvertToSafe(ConvertToUnsafe(plan)),
input.map(Row.fromTuple)
)
}
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
SparkPlan.currentContext.set(TestSQLContext)
SparkPlan.currentContext.set(ctx)
val schema = ArrayType(StringType)
val rows = (1 to 100).map { i =>
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))

View file

@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
class SortSuite extends SparkPlanTest {
class SortSuite extends SparkPlanTest with SharedSQLContext {
// This test was originally added as an example of how to use [[SparkPlanTest]];
// it's not designed to be a comprehensive test of ExternalSort.

View file

@ -17,29 +17,27 @@
package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
/**
* Base class for writing tests for individual physical operators. For an example of how this
* class's test helper methods can be used, see [[SortSuite]].
*/
class SparkPlanTest extends SparkFunSuite {
protected def sqlContext: SQLContext = TestSQLContext
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
protected def _sqlContext: SQLContext
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
sqlContext.implicits.localSeqToDataFrameHolder(data)
_sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
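For a concrete usage of the `checkAnswer` helpers above, the ConvertToSafe/ConvertToUnsafe round trip from RowFormatConvertersSuite (earlier in this diff) is the simplest plan function to borrow. A sketch of a suite built on SparkPlanTest plus SharedSQLContext, which, judging by the suites above, satisfies the abstract `_sqlContext` member:

```scala
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative suite: applies a physical-plan transformation to the input
// DataFrame's plan and lets SparkPlanTest execute and compare the result.
class SparkPlanTestUsageSketch extends SparkPlanTest with SharedSQLContext {
  test("unsafe/safe round trip preserves rows") {
    val input = Seq(("hello", 1), ("world", 2))
    checkAnswer(
      ctx.createDataFrame(input),
      plan => ConvertToSafe(ConvertToUnsafe(plan)),   // same transformation as RowFormatConvertersSuite
      input.map(Row.fromTuple))
  }
}
```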
@ -151,13 +149,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
sqlContext: SQLContext): Option[String] = {
_sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
executePlan(expectedOutputPlan, sqlContext)
executePlan(expectedOutputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -172,7 +170,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
executePlan(outputPlan, sqlContext)
executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -212,12 +210,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
sqlContext: SQLContext): Option[String] = {
_sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
executePlan(outputPlan, sqlContext)
executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -280,10 +278,10 @@ object SparkPlanTest {
}
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
val resolvedPlan = sqlContext.prepareForExecution.execute(
val resolvedPlan = _sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap

View file

@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* A test suite that generates randomized data to test the [[TungstenSort]] operator.
*/
class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
override def beforeAll(): Unit = {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
super.beforeAll()
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
}
override def afterAll(): Unit = {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
try {
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
} finally {
super.afterAll()
}
}
test("sort followed by limit") {
@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
}
test("sorting updates peak execution memory") {
val sc = TestSQLContext.sparkContext
val sc = ctx.sparkContext
AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
) {
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
val inputData = Seq.fill(1000)(randomDataGenerator())
val inputDf = TestSQLContext.createDataFrame(
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
val inputDf = ctx.createDataFrame(
ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
assert(TungstenSort.supportsSchema(inputDf.schema))

View file

@ -26,7 +26,7 @@ import org.scalatest.Matchers
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
*
* Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
*/
class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
class UnsafeFixedWidthAggregationMapSuite
extends SparkFunSuite
with Matchers
with SharedSQLContext {
import UnsafeFixedWidthAggregationMap._
@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with an empty map") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with empty records") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()

View file

@ -23,15 +23,14 @@ import org.apache.spark._
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
*/
class UnsafeKVExternalSorterSuite extends SparkFunSuite {
class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
inputData: Seq[(InternalRow, InternalRow)],
pageSize: Long,
spill: Boolean): Unit = {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager

View file

@ -21,15 +21,12 @@ import org.apache.spark._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.memory.TaskMemoryManager
class TungstenAggregationIteratorSuite extends SparkFunSuite {
class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
test("memory acquired on construction") {
// set up environment
val ctx = TestSQLContext
val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
TaskContext.setTaskContext(taskContext)

View file

@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.rdd.RDD
import org.scalactic.Tolerance._
import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
override def sqlContext: SQLContext = ctx // used by SQLTestUtils
import ctx.sql
import ctx.implicits._
class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
import testImplicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
ctx.read.schema(schema).json(path)
.queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.paths === Array(path))
assert(relationWithSchema.schema === schema)
@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
}
test("JSONRelation equality test") {
val context = org.apache.spark.sql.test.TestSQLContext
val relation0 = new JSONRelation(
Some(empty),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation0 !== relation1)
@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = ResolvedDataSource(
context,
ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
val d2 = ResolvedDataSource(
context,
ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
"abd")
ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
checkAnswer(
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
})
}
}

View file

@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
trait TestJsonData {
protected def ctx: SQLContext
private[json] trait TestJsonData {
protected def _sqlContext: SQLContext
def primitiveFieldAndType: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@ -36,7 +35,7 @@ trait TestJsonData {
}""" :: Nil)
def primitiveFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@ -47,14 +46,14 @@ trait TestJsonData {
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
def jsonNullStruct: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
def complexFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@ -65,14 +64,14 @@ trait TestJsonData {
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
def arrayElementTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
def missingFields: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
@ -80,7 +79,7 @@ trait TestJsonData {
"""{"e":"str"}""" :: Nil)
def complexFieldAndType1: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@ -96,7 +95,7 @@ trait TestJsonData {
}""" :: Nil)
def complexFieldAndType2: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@ -150,7 +149,7 @@ trait TestJsonData {
}""" :: Nil)
def mapType1: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
@ -158,7 +157,7 @@ trait TestJsonData {
"""{"map": {"e": null}}""" :: Nil)
def mapType2: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@ -167,21 +166,21 @@ trait TestJsonData {
"""{"map": {"f": {"field1": null}}}""" :: Nil)
def nullsInArrays: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
def jsonArray: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
def corruptRecords: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@ -190,7 +189,7 @@ trait TestJsonData {
"""]""" :: Nil)
def emptyRecords: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@ -198,9 +197,8 @@ trait TestJsonData {
"""{"b": [{"c": {}}]}""" ::
"""]""" :: Nil)
lazy val singleRow: RDD[String] =
ctx.sparkContext.parallelize(
"""{"a":123}""" :: Nil)
def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
}
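TestJsonData above shows the shape shared by the data traits in this patch (ParquetTest and SparkPlanTest do the same): instead of reaching for a singleton, the trait declares an abstract `protected def _sqlContext` and builds its RDDs from whatever context the mixing-in suite supplies. Judging by JsonSuite above, `SharedSQLContext` implements that member with the per-suite context. A small sketch of the wiring, with an invented trait and suite:

```scala
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{QueryTest, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative data trait following the same shape as TestJsonData.
private[json] trait ExampleJsonData {
  protected def _sqlContext: SQLContext

  def tinyJson: RDD[String] =
    _sqlContext.sparkContext.parallelize("""{"a": 1}""" :: """{"a": 2}""" :: Nil)
}

// The suite mixes in SharedSQLContext (assumed to provide _sqlContext and ctx),
// so the data trait compiles against the per-suite context.
class ExampleJsonSuite extends QueryTest with SharedSQLContext with ExampleJsonData {
  test("read the sketch data") {
    assert(ctx.read.json(tinyJson).count() === 2)
  }
}
```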

View file

@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
import org.apache.hadoop.fs.Path
import org.apache.parquet.avro.AvroParquetWriter
import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.parquet.test.avro._
import org.apache.spark.sql.test.SharedSQLContext
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
override val sqlContext: SQLContext = TestSQLContext
private def withWriter[T <: IndexedRecord]
(path: String, schema: Schema)
(f: AvroParquetWriter[T] => Unit) = {
(f: AvroParquetWriter[T] => Unit): Unit = {
val writer = new AvroParquetWriter[T](new Path(path), schema)
try f(writer) finally writer.close()
}
@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
}
test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
import sqlContext.implicits._
import testImplicits._
withTempPath { dir =>
val path = dir.getCanonicalPath

View file

@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.MessageType
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.QueryTest
abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
def readParquetSchema(path: String): MessageType = {
/**
* Helper class for testing Parquet compatibility.
*/
private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
protected def readParquetSchema(path: String): MessageType = {
readParquetSchema(path, { path => !path.getName.startsWith("_") })
}
def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
val fsPath = new Path(path)
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {

View file

@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.filter2.predicate.Operators._
import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
* 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
* data type is nullable.
*/
class ParquetFilterSuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
private def checkFilterPredicate(
df: DataFrame,
@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
}
test("SPARK-6554: don't push down predicates which reference partition columns") {
import sqlContext.implicits._
import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>

View file

@ -37,6 +37,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
/**
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
import testImplicits._
/**
* Writes `data` to a Parquet file, reads it back and checks the file contents.

View file

@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.unsafe.types.UTF8String
import PartitioningUtils._
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
import PartitioningUtils._
import testImplicits._
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

View file

@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.test.SharedSQLContext
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
override def sqlContext: SQLContext = TestSQLContext
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
private def readParquetProtobufFile(name: String): DataFrame = {
val url = Thread.currentThread().getContextClassLoader.getResource(name)

View file

@ -21,16 +21,15 @@ import java.io.File
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.sql
class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
}
sqlContext.catalog.unregisterTable(Seq("tmp"))
ctx.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
}
sqlContext.catalog.unregisterTable(Seq("tmp"))
ctx.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
val df2 = sqlContext.read.parquet(file.getCanonicalPath)
val df2 = ctx.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
}
}
@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
// delete summary files, so if we don't merge part-files, one column will not be included.
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
// Disables the global SQL option for schema merging
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
assertResult(2) {
// Disables schema merging via data source option
sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
}
assertResult(3) {
// Enables schema merging via data source option
sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
}
}
}
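The Parquet suites above lean on the helpers defined in ParquetTest (further down in this diff): `withParquetTable` writes the given rows to a temporary Parquet file, registers it as a temp table for the body of the test, and cleans both up afterwards. A sketch of a suite using it (suite name and data are illustrative):

```scala
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative only: round-trips a tiny table through Parquet via the
// withParquetTable helper and checks the contents with a SQL query.
class ParquetUsageSketchSuite extends QueryTest with ParquetTest with SharedSQLContext {
  test("round-trip a small table through Parquet") {
    withParquetTable((0 until 3).map(i => (i, i.toString)), "t") {
      checkAnswer(sql("SELECT * FROM t"), Seq(Row(0, "0"), Row(1, "1"), Row(2, "2")))
    }
  }
}
```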

View file

@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
val sqlContext = TestSQLContext
abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.

View file

@ -22,9 +22,8 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
/**
* A helper trait that provides convenient facilities for Parquet testing.
@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* convenient to use tuples rather than special case classes when writing test cases/suites.
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
private[sql] trait ParquetTest extends SQLTestUtils {
protected def _sqlContext: SQLContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
* returns.
@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
_sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
}
/**
@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetDataFrame(data) { df =>
sqlContext.registerDataFrameAsTable(df, tableName)
_sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
_sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](

View file

@ -17,14 +17,12 @@
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
override val sqlContext: SQLContext = TestSQLContext
private val parquetFilePath =
Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")

View file

@ -18,10 +18,10 @@
package org.apache.spark.sql.execution.debug
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.SharedSQLContext
class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
class DebuggingSuite extends SparkFunSuite {
test("DataFrame.debug()") {
testData.debug()
}
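DebuggingSuite above only needs the shared `testData`; the `debug()` call on a DataFrame is supplied by implicits in the enclosing `org.apache.spark.sql.execution.debug` package (our reading of why the suite needs no extra import). A sketch of calling it on an ad-hoc DataFrame, with an invented suite name:

```scala
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative companion to DebuggingSuite: debug() prints the instrumented
// plan for any DataFrame built from the per-suite context.
class DebuggingSketchSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("debug() on an ad-hoc DataFrame") {
    Seq((1, "a"), (2, "b")).toDF("key", "value").debug()
  }
}
```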

View file

@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite {
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
// Key is simply the record itself
private val keyProjection = new Projection {
@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy()).toArray

View file

@ -17,97 +17,19 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
private def testInnerJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}
def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using BroadcastHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using ShuffledHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using ShuffledHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using SortMergeJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeSortMergeJoin(left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
{
val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val myUpperCaseData = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
Row(3, "C"),
@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))
val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val myLowerCaseData = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c"),
@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "e")
)), new StructType().add("n", IntegerType).add("l", StringType))
testInnerJoin(
"inner join, one match per row",
upperCaseData,
lowerCaseData,
(upperCaseData.col("N") === lowerCaseData.col("n")).expr,
Seq(
(1, "A", 1, "a"),
(2, "B", 2, "b"),
(3, "C", 3, "c"),
(4, "D", 4, "d")
)
)
}
private val testData2 = Seq(
private lazy val myTestData = Seq(
(1, 1),
(1, 2),
(2, 1),
@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
(3, 2)
).toDF("a", "b")
// Note: the input DataFrames and the expression must be evaluated lazily because
// the SQLContext should only be used within a test, to keep the SQL tests stable
private def testInnerJoin(
testName: String,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: () => Expression,
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
ExtractEquiJoinKeys.unapply(join)
}
def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}
def makeShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
def makeSortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastHashJoin (build=right)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using ShuffledHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using ShuffledHashJoin (build=right)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using SortMergeJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
testInnerJoin(
"inner join, one match per row",
myUpperCaseData,
myLowerCaseData,
() => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
Seq(
(1, "A", 1, "a"),
(2, "B", 2, "b"),
(3, "C", 3, "c"),
(4, "D", 4, "d")
)
)
{
val left = testData2.where("a = 1")
val right = testData2.where("a = 1")
lazy val left = myTestData.where("a = 1")
lazy val right = myTestData.where("a = 1")
testInnerJoin(
"inner join, multiple matches",
left,
right,
(left.col("a") === right.col("a")).expr,
() => (left.col("a") === right.col("a")).expr,
Seq(
(1, 1, 1, 1),
(1, 1, 1, 2),
@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
}
{
val left = testData2.where("a = 1")
val right = testData2.where("a = 2")
lazy val left = myTestData.where("a = 1")
lazy val right = myTestData.where("a = 2")
testInnerJoin(
"inner join, no matches",
left,
right,
(left.col("a") === right.col("a")).expr,
() => (left.col("a") === right.col("a")).expr,
Seq.empty
)
}
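
The rewritten `testInnerJoin` above takes its inputs by name (`=> DataFrame`) so that nothing touches the shared `SQLContext` while the suite is still being constructed. A standalone sketch of that pattern, with illustrative names only:

```scala
package org.apache.spark.sql.execution.joins

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative suite, not part of this patch: by-name `=> DataFrame` parameters defer
// DataFrame construction until the test body runs, after the shared context exists.
class LazyInputSketchSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  private lazy val people = Seq((1, "a"), (2, "b")).toDF("id", "name")

  private def testWithLazyInput(name: String, rows: => DataFrame): Unit = {
    test(name) {
      assert(rows.count() === 2)  // `rows` is only evaluated here, inside the test
    }
  }

  testWithLazyInput("by-name inputs are evaluated inside the test", people)
}
```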

View file

@ -17,79 +17,19 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
private def testOuterJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
joinType: JoinType,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using SortMergeOuterJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
}
}
}
test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val left = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(2, 100.0),
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
@ -100,7 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val right = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, -1.0),
@ -113,12 +54,64 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
val condition = {
And(
(left.col("a") === right.col("c")).expr,
private lazy val condition = {
And((left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
// Note: the input DataFrames and the expression must be evaluated lazily because
// the SQLContext should only be used within a test, to keep the SQL tests stable
private def testOuterJoin(
testName: String,
leftRows: => DataFrame,
rightRows: => DataFrame,
joinType: JoinType,
condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join)
}
test(s"$testName using ShuffledHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using SortMergeOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
}
}
}
}
// --- Basic outer joins ------------------------------------------------------------------------
testOuterJoin(

View file

@ -17,27 +17,61 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
Row(2, 1.0),
Row(3, 3.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
private lazy val right = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
private lazy val condition = {
And((left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
// Note: the input DataFrames and the expression must be evaluated lazily because
// the SQLContext should only be used within a test, to keep the SQL tests stable
private def testLeftSemiJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
ExtractEquiJoinKeys.unapply(join)
}
test(s"$testName using LeftSemiJoinHash") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
@ -46,8 +80,10 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastLeftSemiJoinHash") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
}
}
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
Row(2, 1.0),
Row(3, 3.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
val condition = {
And(
(left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
testLeftSemiJoin(
"basic test",
left,

View file

@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
override val sqlContext = TestSQLContext
import sqlContext.implicits._
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("LongSQLMetric should not box Long") {
val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Normal accumulator should do boxing") {
// We need this test to make sure BoxingFinder works.
val l = TestSQLContext.sparkContext.accumulator(0L)
val l = ctx.sparkContext.accumulator(0L)
val f = () => { l += 1L }
BoxingFinder.getClassReader(f.getClass).foreach { cl =>
val boxingFinder = new BoxingFinder()
@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
df: DataFrame,
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
val previousExecutionIds = ctx.listener.executionIdToData.keySet
df.collect()
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= expectedNumOfJobs)
if (jobs.size == expectedNumOfJobs) {
// If we can track all jobs, check the metric values
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
val metricValues = ctx.listener.getExecutionMetrics(executionId)
val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
expectedMetrics.contains(node.id)
}.map { node =>
@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Project(nodeId = 0)
val df = TestData.person.select('name)
val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("Project", Map(
"number of rows" -> 2L)))
@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "true") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = TestData.person.select('name)
val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("TungstenProject", Map(
"number of rows" -> 2L)))
@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Filter metrics") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
val df = TestData.person.filter('age < 25)
val df = person.filter('age < 25)
testSparkPlanMetrics(df, 1, Map(
0L -> ("Filter", Map(
"number of input rows" -> 2L,
@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) ->
// SortBasedAggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2)
// -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
3L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1)
// -> TungstenAggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use a deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
// this test should use a deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("ShuffledHashJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("BroadcastNestedLoopJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
}
test("CartesianProduct metrics") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("save metrics") {
withTempPath { file =>
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
val previousExecutionIds = ctx.listener.executionIdToData.keySet
// Assume the execution plan is
// PhysicalRDD(nodeId = 0)
TestData.person.select('name).write.format("json").save(file.getAbsolutePath)
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
person.select('name).write.format("json").save(file.getAbsolutePath)
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= 1)
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
val metricValues = ctx.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we can still check the value.
assert(metricValues.values.toSeq === Seq(2L))
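
The metric checks above all follow the same shape: snapshot the known execution IDs, run the query, wait for the listener bus to drain, and then inspect only the newly created execution. A condensed sketch of that shape, assuming it lives in the same package as the suite; the trait name is made up and the calls mirror the hunks above:

```scala
package org.apache.spark.sql.execution.metric

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative helper, not part of this patch.
trait MetricCheckSketch { self: SharedSQLContext =>
  /** Run `df` and return the ID of the single SQL execution it produced. */
  protected def newExecutionIdOf(df: DataFrame): Long = {
    val before = ctx.listener.executionIdToData.keySet
    df.collect()
    ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
    val created = ctx.listener.executionIdToData.keySet.diff(before)
    assert(created.size == 1, "expected exactly one new SQL execution")
    created.head
  }
}
```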

View file

@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
class SQLListenerSuite extends SparkFunSuite {
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
private def createTestDataFrame: DataFrame = {
import TestSQLContext.implicits._
Seq(
(1, 1),
(2, 2)
@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("basic") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
val accumulatorIds =
@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(

View file

@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
import testImplicits._
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
Some(StringType)
}
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Utils.classForName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test

View file

@ -23,11 +23,13 @@ import java.util.Properties
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
val url1 = "jdbc:h2:mem:testdb3"
@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Utils.classForName("org.h2.Driver")
conn = DriverManager.getConnection(url)
@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
ctx.sql(
sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
ctx.sql(
sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}

View file

@ -19,28 +19,32 @@ package org.apache.spark.sql.sources
import java.io.{File, IOException}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.DDLException
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
private var path: File = null
override def beforeAll(): Unit = {
super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
}
override def afterAll(): Unit = {
try {
caseInsensitiveContext.dropTempTable("jt")
} finally {
super.afterAll()
}
}
after {

View file

@ -18,11 +18,12 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
// Please note that the META-INF/services file in the test directory had to be modified for this to work
class DDLSourceLoadSuite extends DataSourceTest {
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
test("data sources with the same name") {
intercept[RuntimeException] {

View file

@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
}
}
class DDLTestSuite extends DataSourceTest {
class DDLTestSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
before {
caseInsensitiveContext.sql(
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE ddlPeople
|USING org.apache.spark.sql.sources.DDLScanSource
@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest {
))
test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
val attributes = caseInsensitiveContext.sql("describe ddlPeople")
val attributes = sql("describe ddlPeople")
.queryExecution.executedPlan.output
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
assert(attributes.map(_.dataType).toSet === Set(StringType))

View file

@ -17,18 +17,23 @@
package org.apache.spark.sql.sources
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext
abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
private[sql] abstract class DataSourceTest extends QueryTest {
protected def _sqlContext: SQLContext
// We want to test some edge cases.
protected implicit lazy val caseInsensitiveContext = {
val ctx = new SQLContext(TestSQLContext.sparkContext)
protected lazy val caseInsensitiveContext: SQLContext = {
val ctx = new SQLContext(_sqlContext.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) {
test(sqlString) {
checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer)
}
}
}
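
A skeleton of how the surrounding data source suites are wired together after this change: `SharedSQLContext` supplies `_sqlContext`, `DataSourceTest` derives the case-insensitive child context from it, and the `sql` shorthand is routed through that context. The suite name and query are illustrative, not from the patch:

```scala
package org.apache.spark.sql.sources

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite; it mirrors the overrides used by DDLTestSuite, InsertSuite, etc.
class CaseInsensitiveSketchSuite extends DataSourceTest with SharedSQLContext {
  // Route the `sql` shorthand through the case-insensitive context.
  protected override lazy val sql = caseInsensitiveContext.sql _

  test("identifiers resolve case-insensitively in this context") {
    sql("SELECT 1 AS value").registerTempTable("sketchTable")
    checkAnswer(sql("SELECT VALUE FROM SKETCHTABLE"), Row(1))
  }
}
```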

View file

@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@ -96,11 +97,11 @@ object FiltersPushed {
var list: Seq[Filter] = Nil
}
class FilteredScanSuite extends DataSourceTest {
class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
import caseInsensitiveContext.sql
before {
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTenFiltered

View file

@ -19,20 +19,17 @@ package org.apache.spark.sql.sources
import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class InsertSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
private var path: File = null
override def beforeAll(): Unit = {
super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
try {
caseInsensitiveContext.dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jt")
Utils.deleteRecursively(path)
} finally {
super.afterAll()
}
}
test("Simple INSERT OVERWRITE a JSONRelation") {
@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a * 2 FROM jsonTable"),
(1 to 10).map(i => Row(i * 2)).toSeq)
assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
checkAnswer(
sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
assertCached(sql(
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
checkAnswer(sql(
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Insert overwrite and keep the same schema.

View file

@ -19,21 +19,21 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class PartitionedWriteSuite extends QueryTest {
import TestSQLContext.implicits._
class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("write many partitions") {
val path = Utils.createTempDir()
path.delete()
val df = TestSQLContext.range(100).select($"id", lit(1).as("data"))
val df = ctx.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
TestSQLContext.read.load(path.getCanonicalPath),
ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest {
val path = Utils.createTempDir()
path.delete()
val base = TestSQLContext.range(100)
val base = ctx.range(100)
val df = base.unionAll(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
TestSQLContext.read.load(path.getCanonicalPath),
ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)

View file

@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class PrunedScanSource extends RelationProvider {
@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}
}
class PrunedScanSuite extends DataSourceTest {
class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
before {
caseInsensitiveContext.sql(
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest {
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
val queryExecution = sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: execution.PhysicalRDD => p
} match {

View file

@ -19,25 +19,22 @@ package org.apache.spark.sql.sources
import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var originalDefaultSource: String = null
var path: File = null
var df: DataFrame = null
private var originalDefaultSource: String = null
private var path: File = null
private var df: DataFrame = null
override def beforeAll(): Unit = {
super.beforeAll()
originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
path = Utils.createTempDir()
@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
try {
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
} finally {
super.afterAll()
}
}
after {
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
Utils.deleteRecursively(path)
}

View file

@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class DefaultSource extends SimpleScanSource
@ -95,8 +96,8 @@ case class AllDataTypesScan(
}
}
class TableScanSuite extends DataSourceTest {
import caseInsensitiveContext.sql
class TableScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
Row(
@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest {
Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))))
}.toSeq
before {
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTen
@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest {
sql("SELECT i * 2 FROM oneToTen"),
(1 to 10).map(i => Row(i * 2)).toSeq)
assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
checkAnswer(
sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
assertCached(sql(
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
checkAnswer(sql(
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Verify uncaching

View file

@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.test
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
/**
* A collection of sample data used in SQL tests.
*/
private[sql] trait SQLTestData { self =>
protected def _sqlContext: SQLContext
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self._sqlContext
}
import internalImplicits._
import SQLTestData._
// Note: all test data should be lazy because the SQLContext is not set up yet.
protected lazy val testData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("testData")
df
}
protected lazy val testData2: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil, 2).toDF()
df.registerTempTable("testData2")
df
}
protected lazy val testData3: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
df.registerTempTable("testData3")
df
}
protected lazy val negativeData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
df.registerTempTable("negativeData")
df
}
protected lazy val largeAndSmallInts: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
LargeAndSmallInts(3, 2) :: Nil).toDF()
df.registerTempTable("largeAndSmallInts")
df
}
protected lazy val decimalData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
DecimalData(3, 2) :: Nil).toDF()
df.registerTempTable("decimalData")
df
}
protected lazy val binaryData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
BinaryData("12".getBytes, 1) ::
BinaryData("22".getBytes, 5) ::
BinaryData("122".getBytes, 3) ::
BinaryData("121".getBytes, 2) ::
BinaryData("123".getBytes, 4) :: Nil).toDF()
df.registerTempTable("binaryData")
df
}
protected lazy val upperCaseData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
UpperCaseData(6, "F") :: Nil).toDF()
df.registerTempTable("upperCaseData")
df
}
protected lazy val lowerCaseData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil).toDF()
df.registerTempTable("lowerCaseData")
df
}
protected lazy val arrayData: RDD[ArrayData] = {
val rdd = _sqlContext.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
rdd.toDF().registerTempTable("arrayData")
rdd
}
protected lazy val mapData: RDD[MapData] = {
val rdd = _sqlContext.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
rdd.toDF().registerTempTable("mapData")
rdd
}
protected lazy val repeatedData: RDD[StringData] = {
val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("repeatedData")
rdd
}
protected lazy val nullableRepeatedData: RDD[StringData] = {
val rdd = _sqlContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("nullableRepeatedData")
rdd
}
protected lazy val nullInts: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil).toDF()
df.registerTempTable("nullInts")
df
}
protected lazy val allNulls: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
NullInts(null) :: Nil).toDF()
df.registerTempTable("allNulls")
df
}
protected lazy val nullStrings: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
df.registerTempTable("nullStrings")
df
}
protected lazy val tableName: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
df.registerTempTable("tableName")
df
}
protected lazy val unparsedStrings: RDD[String] = {
_sqlContext.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)
}
// An RDD with 4 elements and 8 partitions
protected lazy val withEmptyParts: RDD[IntField] = {
val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
rdd.toDF().registerTempTable("withEmptyParts")
rdd
}
protected lazy val person: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
df.registerTempTable("person")
df
}
protected lazy val salary: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
df.registerTempTable("salary")
df
}
protected lazy val complexData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
Nil).toDF()
df.registerTempTable("complexData")
df
}
/**
* Initialize all test data such that all temp tables are properly registered.
*/
def loadTestData(): Unit = {
assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
testData
testData2
testData3
negativeData
largeAndSmallInts
decimalData
binaryData
upperCaseData
lowerCaseData
arrayData
mapData
repeatedData
nullableRepeatedData
nullInts
allNulls
nullStrings
tableName
unparsedStrings
withEmptyParts
person
salary
complexData
}
}
/**
* Case classes used in test data.
*/
private[sql] object SQLTestData {
case class TestData(key: Int, value: String)
case class TestData2(a: Int, b: Int)
case class TestData3(a: Int, b: Option[Int])
case class LargeAndSmallInts(a: Int, b: Int)
case class DecimalData(a: BigDecimal, b: BigDecimal)
case class BinaryData(a: Array[Byte], b: Int)
case class UpperCaseData(N: Int, L: String)
case class LowerCaseData(n: Int, l: String)
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
case class MapData(data: scala.collection.Map[Int, String])
case class StringData(s: String)
case class IntField(i: Int)
case class NullInts(a: Integer)
case class NullStrings(n: Int, s: String)
case class TableName(tableName: String)
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
}
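
To show how `SQLTestData` is consumed, here is an illustrative suite (not part of the patch) that uses one member both as a direct DataFrame reference and, after `setupTestData()`, by its registered temp-table name:

```scala
package org.apache.spark.sql.test

import org.apache.spark.SparkFunSuite

// Hypothetical suite name; testData2 and the "testData2" temp table come from SQLTestData.
class TestDataSketchSuite extends SparkFunSuite with SharedSQLContext {
  // Ask SQLTestUtils to call loadTestData() once the SQLContext is up,
  // so that SQL-by-name access works.
  setupTestData()

  test("direct reference and by-name access agree") {
    assert(testData2.count() === 6)
    assert(sql("SELECT COUNT(*) FROM testData2").head().getLong(0) === 6)
  }
}
```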

View file

@ -21,15 +21,71 @@ import java.io.File
import java.util.UUID
import scala.util.Try
import scala.language.implicitConversions
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.util.Utils
trait SQLTestUtils { this: SparkFunSuite =>
protected def sqlContext: SQLContext
/**
* Helper trait that should be extended by all SQL test suites.
*
* This allows subclasses to plug in a custom [[SQLContext]]. It comes with test data
* prepared in advance as well as all the implicit conversions used extensively by DataFrames.
* To use the implicit methods, import `testImplicits._` instead of importing them from the [[SQLContext]].
*
* Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is
* prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
*/
private[sql] trait SQLTestUtils
extends SparkFunSuite
with BeforeAndAfterAll
with SQLTestData { self =>
protected def configuration = sqlContext.sparkContext.hadoopConfiguration
protected def _sqlContext: SQLContext
// Whether to materialize all test data before the first test is run
private var loadTestDataBeforeTests = false
// Shorthand for running a query using our SQLContext
protected lazy val sql = _sqlContext.sql _
/**
* A helper object for importing SQL implicits.
*
* Note that the alternative of importing `sqlContext.implicits._` is not possible here.
* This is because we create the [[SQLContext]] immediately before the first test is run,
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self._sqlContext
}
/**
* Materialize the test data immediately after the [[SQLContext]] is set up.
* This is necessary if the data is accessed by name (e.g. from a SQL query) rather than through a direct reference.
*/
protected def setupTestData(): Unit = {
loadTestDataBeforeTests = true
}
protected override def beforeAll(): Unit = {
super.beforeAll()
if (loadTestDataBeforeTests) {
loadTestData()
}
}
/**
* The Hadoop configuration used by the active [[SQLContext]].
*/
protected def configuration: Configuration = {
_sqlContext.sparkContext.hadoopConfiguration
}
/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL configurations.
@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
(keys, values).zipped.foreach(sqlContext.conf.setConfString)
val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
(keys, values).zipped.foreach(_sqlContext.conf.setConfString)
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
case (key, None) => sqlContext.conf.unsetConf(key)
case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
case (key, None) => _sqlContext.conf.unsetConf(key)
}
}
}
@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
* Drops the temporary tables `tableNames` after calling `f`.
*/
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(sqlContext.dropTempTable)
try f finally tableNames.foreach(_sqlContext.dropTempTable)
}
/**
@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
try f finally {
tableNames.foreach { name =>
sqlContext.sql(s"DROP TABLE IF EXISTS $name")
_sqlContext.sql(s"DROP TABLE IF EXISTS $name")
}
}
}
@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
try {
sqlContext.sql(s"CREATE DATABASE $dbName")
_sqlContext.sql(s"CREATE DATABASE $dbName")
} catch { case cause: Throwable =>
fail("Failed to create temporary database", cause)
}
try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
}
/**
@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite =>
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
sqlContext.sql(s"USE $db")
try f finally sqlContext.sql(s"USE default")
_sqlContext.sql(s"USE $db")
try f finally _sqlContext.sql(s"USE default")
}
/**
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier
* way to construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
DataFrame(_sqlContext, plan)
}
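A minimal sketch of what the conversion enables, assuming the suite also imports testImplicits._ and the catalyst DSLs (and noting that DataFrame.logicalPlan is protected[sql], so the suite must live under org.apache.spark.sql, as these do):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
// Build a raw LogicalPlan with the DSL and pass it where a DataFrame is expected;
// the implicit above wraps it in DataFrame(_sqlContext, plan).
checkAnswer(df.logicalPlan.where('key === 1), Row(1, "a"))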
}

View file

@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.test
import org.apache.spark.sql.SQLContext
/**
* Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
*/
private[sql] trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*
* By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
* mode with the default test configurations.
*/
private var _ctx: TestSQLContext = null
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
protected def ctx: TestSQLContext = _ctx
protected def sqlContext: TestSQLContext = _ctx
protected override def _sqlContext: SQLContext = _ctx
/**
* Initialize the [[TestSQLContext]].
*/
protected override def beforeAll(): Unit = {
if (_ctx == null) {
_ctx = new TestSQLContext
}
// Ensure we have initialized the context before calling parent code
super.beforeAll()
}
/**
* Stop the underlying [[org.apache.spark.SparkContext]], if any.
*/
protected override def afterAll(): Unit = {
try {
if (_ctx != null) {
_ctx.sparkContext.stop()
_ctx = null
}
} finally {
super.afterAll()
}
}
}
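A minimal suite built on the shared context might then look like the following sketch (hypothetical suite name; QueryTest supplies checkAnswer):

package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MySharedContextSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("runs against the per-suite test context") {
    // The default TestSQLContext constructor (next file) sets this key.
    assert(ctx.sparkContext.getConf.get("spark.sql.testkey") === "true")
    checkAnswer(Seq((1, "a"), (2, "b")).toDF("i", "s").select("i"), Seq(Row(1), Row(2)))
  }
}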

View file

@ -17,40 +17,36 @@
package org.apache.spark.sql.test
import scala.language.implicitConversions
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{SQLConf, SQLContext}
/** A SQLContext that can be used for local testing. */
class LocalSQLContext
extends SQLContext(
new SparkContext("local[2]", "TestSQLContext", new SparkConf()
.set("spark.sql.testkey", "true")
// SPARK-8910
.set("spark.ui.enabled", "false"))) {
override protected[sql] def createSession(): SQLSession = {
new this.SQLSession()
/**
* A special [[SQLContext]] prepared for testing.
*/
private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
def this() {
this(new SparkContext("local[2]", "test-sql-context",
new SparkConf().set("spark.sql.testkey", "true")))
}
// Use fewer partitions to speed up testing
protected[sql] override def createSession(): SQLSession = new this.SQLSession()
/** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
/** Fewer partitions to speed up testing. */
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
}
}
/**
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to
* construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
DataFrame(this, plan)
// Needed for Java tests
def loadTestData(): Unit = {
testData.loadTestData()
}
private object testData extends SQLTestData {
protected override def _sqlContext: SQLContext = self
}
}
object TestSQLContext extends LocalSQLContext
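Since the primary constructor now takes an arbitrary SparkContext, a test can construct the context over whatever SparkConf it needs. A sketch with illustrative settings (the class is private[sql], so this has to live under org.apache.spark.sql):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.test.TestSQLContext

val conf = new SparkConf()
  .setMaster("local[4]")                     // any master string can go here
  .setAppName("custom-test-sql-context")
  .set("spark.sql.shuffle.partitions", "2")  // illustrative extra setting
val ctx = new TestSQLContext(new SparkContext(conf))
try {
  assert(ctx.sql("SELECT 1 AS one").collect().head.getInt(0) == 1)
} finally {
  ctx.sparkContext.stop()
}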

View file

@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ui.SparkUICssErrorHandler
class UISeleniumSuite
@ -36,7 +35,6 @@ class UISeleniumSuite
implicit var webDriver: WebDriver = _
var server: HiveThriftServer2 = _
var hc: HiveContext = _
val uiPort = 20000 + Random.nextInt(10000)
override def mode: ServerMode.Value = ServerMode.binary

View file

@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.sources.DataSourceTest
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.{Row, SaveMode, SQLContext}
import org.apache.spark.{Logging, SparkFunSuite}
@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
}
class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
import testImplicits._
private val testDF = range(1, 3).select(
('id + 0.1) cast DecimalType(10, 3) as 'd1,

View file

@ -19,14 +19,13 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row, SQLContext}
case class Cases(lower: String, UPPER: String)
class HiveParquetSuite extends QueryTest with ParquetTest {
val sqlContext = TestHive
import sqlContext._
private val ctx = TestHive
override def _sqlContext: SQLContext = ctx
test("Case insensitive attribute names") {
withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("Converting Hive to Parquet Table via saveAsParquetFile") {
withTempPath { dir =>
sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
read.parquet(dir.getCanonicalPath).registerTempTable("p")
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
checkAnswer(
sql("SELECT * FROM src ORDER BY key"),
@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
read.parquet(file.getCanonicalPath).registerTempTable("p")
ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")

View file

@ -22,7 +22,6 @@ import java.io.{IOException, File}
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InvalidInputException
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.Logging
@ -42,7 +41,8 @@ import org.apache.spark.util.Utils
*/
class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
with Logging {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
var jsonFilePath: String = _

View file

@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
override val sqlContext: SQLContext = TestHive
import sqlContext.sql
override val _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
private val df = sqlContext.range(10).coalesce(1)

View file

@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext}
class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest {
import ParquetCompatibilityTest.makeNullable
override val sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
/**
* Set the staging directory (and hence path to ignore Parquet files under)

View file

@ -17,14 +17,12 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.QueryTest
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
import ctx.implicits._
test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })

View file

@ -17,17 +17,18 @@
package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterAll
import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
var originalUseAggregate2: Boolean = _

View file

@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils
* A set of tests that validates support for the Hive `EXPLAIN` command.
*/
class HiveExplainSuite extends QueryTest with SQLTestUtils {
def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
test("explain extended command") {
checkExistence(sql(" explain select * from src where key=123 "), true,

View file

@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect
* valid, but Hive currently cannot execute it.
*/
class SQLQuerySuite extends QueryTest with SQLTestUtils {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
test("UDTF") {
sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")

View file

@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType
class ScriptTransformationSuite extends SparkPlanTest {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
private val noSerdeIOSchema = HiveScriptIOSchema(
inputRowFormat = Seq.empty,

View file

@ -27,8 +27,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.test.SQLTestUtils
private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
import sqlContext.sparkContext

View file

@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
* A collection of tests for Parquet data with various forms of partitioning.
*/
abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
var partitionedTableDir: File = null
var normalTableDir: File = null

View file

@ -18,14 +18,16 @@
package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
// When committing a task, `CommitFailureTestSource` throws an exception for testing purposes.
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName

View file

@ -34,9 +34,8 @@ import org.apache.spark.sql.types._
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
override lazy val sqlContext: SQLContext = TestHive
import sqlContext.sql
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
val dataSourceName: String