[SPARK-9580] [SQL] Replace singletons in SQL tests

A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious restriction because a user may wish to run with a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure.

This patch removes the singletons `TestSQLContext` and `TestData` and instead introduces a `SharedSQLContext` that starts a fresh context for each suite. Unfortunately, the singletons were so deeply ingrained in the SQL tests that this patch had to touch *all* of the SQL test files.
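For reference, `SharedSQLContext` (in `org.apache.spark.sql.test`) gives each suite its own context, created before the suite runs and stopped afterwards, so a suite is free to pick a different master or config. The block below is a hand-written sketch of that pattern rather than the trait added by this patch; it assumes ScalaTest's `BeforeAndAfterAll`, and the name `SharedSQLContextSketch` is illustrative.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.scalatest.{BeforeAndAfterAll, Suite}

// Sketch only: each suite owns its SparkContext/SQLContext instead of sharing
// a global TestSQLContext singleton.
trait SharedSQLContextSketch extends BeforeAndAfterAll { self: Suite =>

  @transient private var _ctx: SQLContext = _

  /** The SQLContext shared by all tests in this suite. */
  protected def ctx: SQLContext = _ctx

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf().setMaster("local[2]").setAppName(suiteName)
    _ctx = new SQLContext(new SparkContext(conf))
  }

  override protected def afterAll(): Unit = {
    try {
      if (_ctx != null) {
        _ctx.sparkContext.stop()
        _ctx = null
      }
    } finally {
      super.afterAll()
    }
  }
}
```

Because the context is owned by the suite, a suite such as `BroadcastJoinSuite` can now construct a context with whatever master or conf it needs.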

Author: Andrew Or <andrew@databricks.com>

Closes #8111 from andrewor14/sql-tests-refactor.
Authored by Andrew Or on 2015-08-13 17:42:01 -07:00, committed by Reynold Xin
commit 8187b3ae47 (parent c50f97dafd)
96 changed files with 1461 additions and 1204 deletions

@ -178,6 +178,16 @@ object MimaExcludes {
// SPARK-4751 Dynamic allocation for standalone mode
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.SparkContext.supportDynamicAllocation")
) ++ Seq(
// SPARK-9580: Remove SQL test singletons
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.LocalSQLContext"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.TestSQLContext"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.test.TestSQLContext$")
) ++ Seq(
// SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
ProblemFilters.exclude[IncompatibleResultTypeProblem](

@ -319,6 +319,8 @@ object SQL {
lazy val settings = Seq(
initialCommands in console :=
"""
|import org.apache.spark.SparkContext
|import org.apache.spark.sql.SQLContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._
@ -328,9 +330,14 @@ object SQL {
|import org.apache.spark.sql.catalyst.util._
|import org.apache.spark.sql.execution
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.test.TestSQLContext._
|import org.apache.spark.sql.types._""".stripMargin,
cleanupCommands in console := "sparkContext.stop()"
|import org.apache.spark.sql.types._
|
|val sc = new SparkContext("local[*]", "dev-shell")
|val sqlContext = new SQLContext(sc)
|import sqlContext.implicits._
|import sqlContext._
""".stripMargin,
cleanupCommands in console := "sc.stop()"
)
}
@ -340,8 +347,6 @@ object Hive {
javaOptions += "-XX:MaxPermSize=256m",
// Specially disable assertions since some Hive tests fail them
javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
// Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
// only for this subproject.
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
@ -349,6 +354,7 @@ object Hive {
},
initialCommands in console :=
"""
|import org.apache.spark.SparkContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._

@ -17,14 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types._
@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode {
override def output: Seq[Attribute] = Nil
}
class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
class AnalysisErrorSuite extends AnalysisTest {
import TestRelations._
def errorTest(

@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConversions._
import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
@ -334,97 +332,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
@Experimental
object implicits extends Serializable {
// scalastyle:on
/**
* Converts $"col name" into an [[Column]].
* @since 1.3.0
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
new ColumnName(sc.s(args: _*))
}
}
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
/**
* Creates a DataFrame from an RDD of case classes or tuples.
* @since 1.3.0
*/
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
DataFrameHolder(self.createDataFrame(rdd))
}
/**
* Creates a DataFrame from a local Seq of Product.
* @since 1.3.0
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
{
DataFrameHolder(self.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
// because of [[DoubleRDDFunctions]].
/**
* Creates a single column DataFrame from an RDD[Int].
* @since 1.3.0
*/
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setInt(0, v)
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[Long].
* @since 1.3.0
*/
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setLong(0, v)
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[String].
* @since 1.3.0
*/
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.update(0, UTF8String.fromString(v))
row: InternalRow
}
}
DataFrameHolder(
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
object implicits extends SQLImplicits with Serializable {
protected override def _sqlContext: SQLContext = self
}
// scalastyle:on
/**
* :: Experimental ::

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.StructField
import org.apache.spark.unsafe.types.UTF8String
/**
* A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
*/
private[sql] abstract class SQLImplicits {
protected def _sqlContext: SQLContext
/**
* Converts $"col name" into an [[Column]].
* @since 1.3.0
*/
implicit class StringToColumn(val sc: StringContext) {
def $(args: Any*): ColumnName = {
new ColumnName(sc.s(args: _*))
}
}
/**
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
/**
* Creates a DataFrame from an RDD of case classes or tuples.
* @since 1.3.0
*/
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
DataFrameHolder(_sqlContext.createDataFrame(rdd))
}
/**
* Creates a DataFrame from a local Seq of Product.
* @since 1.3.0
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
{
DataFrameHolder(_sqlContext.createDataFrame(data))
}
// Do NOT add more implicit conversions. They are likely to break source compatibility by
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
// because of [[DoubleRDDFunctions]].
/**
* Creates a single column DataFrame from an RDD[Int].
* @since 1.3.0
*/
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
val dataType = IntegerType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setInt(0, v)
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[Long].
* @since 1.3.0
*/
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
val dataType = LongType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.setLong(0, v)
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
/**
* Creates a single column DataFrame from an RDD[String].
* @since 1.3.0
*/
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
val dataType = StringType
val rows = data.mapPartitions { iter =>
val row = new SpecificMutableRow(dataType :: Nil)
iter.map { v =>
row.update(0, UTF8String.fromString(v))
row: InternalRow
}
}
DataFrameHolder(
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
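Since `SQLImplicits` only needs a `SQLContext` through `_sqlContext`, the same conversions are now available from any context instance instead of the removed `TestSQLContext` singleton. A quick usage sketch, assuming a `SparkContext` named `sc` is already in scope:

```scala
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)  // `sc` is an assumed, pre-existing SparkContext
import sqlContext.implicits._        // brings in the conversions defined by SQLImplicits

// toDF comes from localSeqToDataFrameHolder; $"..." comes from StringToColumn.
val df = Seq((1, "a"), (2, "b")).toDF("id", "name")
df.where($"id" > 1).show()
```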

@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable {
@Before
public void setUp() {
sqlContext = TestSQLContext$.MODULE$;
javaCtx = new JavaSparkContext(sqlContext.sparkContext());
SparkContext context = new SparkContext("local[*]", "testing");
javaCtx = new JavaSparkContext(context);
sqlContext = new SQLContext(context);
}
@After
public void tearDown() {
javaCtx = null;
sqlContext.sparkContext().stop();
sqlContext = null;
javaCtx = null;
}
public static class Person implements Serializable {

@ -17,44 +17,45 @@
package test.org.apache.spark.sql;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
import org.junit.*;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import scala.collection.JavaConversions;
import scala.collection.Seq;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import org.junit.*;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.types.*;
public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
private transient SQLContext context;
private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
jsc = new JavaSparkContext(TestSQLContext.sparkContext());
context = TestSQLContext$.MODULE$;
SparkContext sc = new SparkContext("local[*]", "testing");
jsc = new JavaSparkContext(sc);
context = new TestSQLContext(sc);
context.loadTestData();
}
@After
public void tearDown() {
jsc = null;
context.sparkContext().stop();
context = null;
jsc = null;
}
@Test
@ -230,7 +231,7 @@ public class JavaDataFrameSuite {
@Test
public void testSampleBy() {
DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};

@ -23,12 +23,12 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable {
@Before
public void setUp() {
sqlContext = TestSQLContext$.MODULE$;
sc = new JavaSparkContext(sqlContext.sparkContext());
SparkContext _sc = new SparkContext("local[*]", "testing");
sqlContext = new SQLContext(_sc);
sc = new JavaSparkContext(_sc);
}
@After
public void tearDown() {
sqlContext.sparkContext().stop();
sqlContext = null;
sc = null;
}
@SuppressWarnings("unchecked")

@ -21,13 +21,14 @@ import java.io.File;
import java.io.IOException;
import java.util.*;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
@ -52,8 +53,9 @@ public class JavaSaveLoadSuite {
@Before
public void setUp() throws IOException {
sqlContext = TestSQLContext$.MODULE$;
sc = new JavaSparkContext(sqlContext.sparkContext());
SparkContext _sc = new SparkContext("local[*]", "testing");
sqlContext = new SQLContext(_sc);
sc = new JavaSparkContext(_sc);
originalDefaultSource = sqlContext.conf().defaultDataSourceName();
path =
@ -71,6 +73,13 @@ public class JavaSaveLoadSuite {
df.registerTempTable("jsonTable");
}
@After
public void tearDown() {
sqlContext.sparkContext().stop();
sqlContext = null;
sc = null;
}
@Test
public void saveAndLoad() {
Map<String, String> options = new HashMap<String, String>();

@ -18,24 +18,20 @@
package org.apache.spark.sql
import scala.concurrent.duration._
import scala.language.{implicitConversions, postfixOps}
import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
private case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
class CachedTableSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def rddIdOf(tableName: String): Int = {
val executedPlan = ctx.table(tableName).queryExecution.executedPlan

@ -21,16 +21,20 @@ import org.scalatest.Matchers._
import org.apache.spark.sql.execution.{Project, TungstenProject}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.SQLTestUtils
class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
override def sqlContext(): SQLContext = ctx
private lazy val booleanData = {
ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
Row(true, true) :: Nil),
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
}
test("column names with space") {
val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
checkAnswer(
ctx.sql("select isnull(null), isnull(1)"),
sql("select isnull(null), isnull(1)"),
Row(true, false))
}
@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
checkAnswer(
ctx.sql("select isnotnull(null), isnotnull('a')"),
sql("select isnotnull(null), isnotnull('a')"),
Row(false, true))
}
@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
checkAnswer(
ctx.sql("select isnan(15), isnan('invalid')"),
sql("select isnan(15), isnan('invalid')"),
Row(false, false))
}
@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
testData.registerTempTable("t")
checkAnswer(
ctx.sql(
sql(
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
" nanvl(b, e), nanvl(e, f) from t"),
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
}
}
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
Row(true, true) :: Nil),
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
test("&&") {
checkAnswer(
booleanData.filter($"a" && true),
@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
ctx.sql("SELECT upper('aB'), ucase('cDe')"),
sql("SELECT upper('aB'), ucase('cDe')"),
Row("AB", "CDE"))
}
@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
ctx.sql("SELECT lower('aB'), lcase('cDe')"),
sql("SELECT lower('aB'), lcase('cDe')"),
Row("ab", "cde"))
}

@ -17,15 +17,13 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{BinaryType, DecimalType}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DecimalType
class DataFrameAggregateSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("groupBy") {
checkAnswer(

@ -17,17 +17,15 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
class DataFrameFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("array with column name") {
val df = Seq((0, 1)).toDF("a", "b")
@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest {
test("constant functions") {
checkAnswer(
ctx.sql("SELECT E()"),
sql("SELECT E()"),
Row(scala.math.E)
)
checkAnswer(
ctx.sql("SELECT PI()"),
sql("SELECT PI()"),
Row(scala.math.Pi)
)
}
@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("nvl function") {
checkAnswer(
ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
Row("x", "y", null))
}
@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(-1)
)
checkAnswer(
ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
sql("SELECT least(a, 2) as l from testData2 order by l"),
Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
)
}
@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(3)
)
checkAnswer(
ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
sql("SELECT greatest(a, 2) as g from testData2 order by g"),
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}

@ -17,10 +17,10 @@
package org.apache.spark.sql
class DataFrameImplicitsSuite extends QueryTest {
import org.apache.spark.sql.test.SharedSQLContext
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("RDD of tuples") {
checkAnswer(

@ -17,14 +17,12 @@
package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameJoinSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("join - join using") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest {
checkAnswer(
df1.join(df2, $"df1.key" === $"df2.key"),
ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
.collect().toSeq)
}

@ -19,11 +19,11 @@ package org.apache.spark.sql
import scala.collection.JavaConversions._
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameNaFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
def createDF(): DataFrame = {
Seq[(String, java.lang.Integer, java.lang.Double)](

@ -19,20 +19,17 @@ package org.apache.spark.sql
import java.util.Random
import org.scalatest.Matchers._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameStatSuite extends QueryTest {
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
import sqlCtx.implicits._
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private def toLetter(i: Int): String = (i + 97).toChar.toString
test("sample with replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
Seq(5, 10, 52, 73).map(Row(_))
@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest {
test("sample without replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
Seq(16, 23, 88, 100).map(Row(_))
@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest {
test("randomSplit") {
val n = 600
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("Frequent Items 2") {
val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
// this is a regression test, where when merging partitions, we omitted values with higher
// counts than those that existed in the map when the map was full. This test should also fail
// if anything like SPARK-9614 is observed once again
@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("sampleBy") {
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),

@ -23,18 +23,12 @@ import scala.language.postfixOps
import scala.util.Random
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.functions._
import org.apache.spark.sql.execution.datasources.json.JSONRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
class DataFrameSuite extends QueryTest with SQLTestUtils {
import org.apache.spark.sql.TestData._
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class DataFrameSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("analysis error should be eagerly reported") {
// Eager analysis.

@ -18,7 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
@ -27,10 +27,8 @@ import org.apache.spark.sql.types._
* This is here for now so I can make sure Tungsten project is tested without refactoring existing
* end-to-end test infra. In the long run this should just go away.
*/
class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("test simple types") {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {

@ -22,19 +22,18 @@ import java.text.SimpleDateFormat
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.CalendarInterval
class DateFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class DateFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("function current_date") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
val d2 = DateTimeUtils.fromJavaDate(
ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
}
@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
// Execution in one query should return the same value
checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
Row(true))
assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
0).getTime - System.currentTimeMillis()) < 5000)
}

@ -17,22 +17,15 @@
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Ensures tables are loaded.
TestData
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.logicalPlanToSparkQuery
setupTestData()
test("equi-join is hash-join") {
val x = testData2.as("x")
@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
def assertJoin(sqlString: String, c: Class[_]): Any = {
val df = ctx.sql(sqlString)
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted hash join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
for (sortMergeJoinEnabled <- Seq(true, false)) {
withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("broadcasted hash outer join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
classOf[BroadcastHashOuterJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("multiple-key equi-join is hash-join") {
@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(6, 1) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 6))
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 10))
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
ctx.sql(
sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted left semi join operator selection") {
ctx.cacheManager.clearCache()
ctx.sql("CACHE TABLE testData")
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
ctx.sql("UNCACHE TABLE testData")
sql("UNCACHE TABLE testData")
}
test("left semi join") {
val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::

@ -17,10 +17,10 @@
package org.apache.spark.sql
class JsonFunctionsSuite extends QueryTest {
import org.apache.spark.sql.test.SharedSQLContext
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("function get_json_object") {
val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")

@ -19,12 +19,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
class ListTablesSuite extends QueryTest with BeforeAndAfter {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
import testImplicits._
private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
tableDF.registerTempTable("tables")
checkAnswer(
ctx.sql(
sql(
"SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)

@ -19,18 +19,16 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
import org.apache.spark.sql.test.SharedSQLContext
private object MathExpressionsTestData {
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
case class NullDoubles(a: java.lang.Double)
}
class MathExpressionsSuite extends QueryTest {
class MathExpressionsSuite extends QueryTest with SharedSQLContext {
import MathExpressionsTestData._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import testImplicits._
private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest {
test("toDegrees") {
testOneToOneMathFunction(toDegrees, math.toDegrees)
checkAnswer(
ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
)
}
@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest {
test("toRadians") {
testOneToOneMathFunction(toRadians, math.toRadians)
checkAnswer(
ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
sql("SELECT radians(0), radians(1), radians(1.5)"),
Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
)
}
@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest {
test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
checkAnswer(
ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0.0, 1.0, 2.0))
}
@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest {
val pi = 3.1415
checkAnswer(
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction[Double](signum, math.signum)
checkAnswer(
ctx.sql("SELECT sign(10), signum(-11)"),
sql("SELECT sign(10), signum(-11)"),
Row(1, -1))
}
@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest {
testTwoToOneMathFunction(pow, pow, math.pow)
checkAnswer(
ctx.sql("SELECT pow(1, 2), power(2, 1)"),
sql("SELECT pow(1, 2), power(2, 1)"),
Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
)
}
@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest {
test("log / ln") {
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
checkAnswer(
ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
sql("SELECT ln(0), ln(1), ln(1.5)"),
Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
)
}
@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest {
df.select(log2("b") + log2("a")),
Row(1))
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
}
test("sqrt") {
@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest {
df.select(sqrt("a"), sqrt("b")),
Row(1.0, 2.0))
checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
}
test("negative") {
checkAnswer(
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
sql("SELECT negative(1), negative(0), negative(-1)"),
Row(-1, 0, 1))
}

@ -71,12 +71,6 @@ class QueryTest extends PlanTest {
checkAnswer(df, expectedAnswer.collect())
}
def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
test(sqlString) {
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
}
}
/**
* Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
*/

@ -20,13 +20,12 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class RowSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class RowSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("create row") {
val expected = new GenericMutableRow(4)

@ -17,11 +17,10 @@
package org.apache.spark.sql
import org.apache.spark.sql.test.SharedSQLContext
class SQLConfSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SQLConfSuite extends QueryTest with SharedSQLContext {
private val testKey = "test.key.0"
private val testVal = "test.val.0"
@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest {
test("parse SQL set commands") {
ctx.conf.clear()
ctx.sql(s"set $testKey=$testVal")
sql(s"set $testKey=$testVal")
assert(ctx.getConf(testKey, testVal + "_") === testVal)
assert(ctx.getConf(testKey, testVal + "_") === testVal)
ctx.sql("set some.property=20")
sql("set some.property=20")
assert(ctx.getConf("some.property", "0") === "20")
ctx.sql("set some.property = 40")
sql("set some.property = 40")
assert(ctx.getConf("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
ctx.sql(s"set $key=$vs")
sql(s"set $key=$vs")
assert(ctx.getConf(key, "0") === vs)
ctx.sql(s"set $key=")
sql(s"set $key=")
assert(ctx.getConf(key, "0") === "")
ctx.conf.clear()
@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest {
test("deprecated property") {
ctx.conf.clear()
ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
assert(ctx.conf.numShufflePartitions === 10)
}
test("invalid conf value") {
ctx.conf.clear()
val e = intercept[IllegalArgumentException] {
ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
}

@ -17,16 +17,17 @@
package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
override def afterAll(): Unit = {
SQLContext.setLastInstantiatedContext(ctx)
try {
SQLContext.setLastInstantiatedContext(ctx)
} finally {
super.afterAll()
}
}
test("getOrCreate instantiates SQLContext") {

@ -19,28 +19,23 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData
class SQLQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
setupTestData()
test("having clause") {
Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("show functions") {
checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
checkAnswer(sql("SHOW functions"),
FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
}
test("describe functions") {
@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
// we except the id is materialized once
val idUDF = udf(() => UUID.randomUUID().toString)
val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString)
val dfWithId = df.withColumn("id", idUDF())
// Make a new DataFrame (actually the same reference to the old one)
@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(
sql(
"""
|SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
""".stripMargin),
"SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
Row(2, 1, 2, 2, 1))
}
@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
validateMetadata(sql(
"SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
}
test("SPARK-3371 Renaming a function expression with group by gives error") {
@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
.toDF("num", "str")
df.registerTempTable("1one")
checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10))
checkAnswer(sql("select count(num) from 1one"), Row(10))
sqlContext.dropTempTable("1one")
}

@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
case class ReflectData(
stringField: String,
@ -71,17 +72,15 @@ case class ComplexReflectData(
mapFieldContainsNull: Map[Int, Option[Long]],
dataField: Data)
class ScalaReflectionRelationSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
Seq(data).toDF().registerTempTable("reflectData")
assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = NullReflectData(null, null, null, null, null, null, null)
Seq(data).toDF().registerTempTable("reflectNullData")
assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
assert(sql("SELECT * FROM reflectNullData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = OptionalReflectData(None, None, None, None, None, None, None)
Seq(data).toDF().registerTempTable("reflectOptionalData")
assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
test("query binary data") {
Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
val result = ctx.sql("SELECT data FROM reflectBinary")
val result = sql("SELECT data FROM reflectBinary")
.collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
}
@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
Nested(None, "abc")))
Seq(data).toDF().registerTempTable("reflectComplexData")
assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
Row(
Seq(1, 2, 3),
Seq(1, 2, null),

@ -19,13 +19,12 @@ package org.apache.spark.sql
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.test.SharedSQLContext
class SerializationSuite extends SparkFunSuite {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
val sqlContext = new SQLContext(ctx.sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
val _sqlContext = new SQLContext(sqlContext.sparkContext)
new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
}
}

@ -18,13 +18,12 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.Decimal
class StringFunctionsSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class StringFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("string concat") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")

@ -1,197 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
val largeAndSmallInts =
TestSQLContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
LargeAndSmallInts(3, 2) :: Nil).toDF()
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
val testData2 =
TestSQLContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil, 2).toDF()
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
val decimalData =
TestSQLContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
DecimalData(3, 2) :: Nil).toDF()
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
val binaryData =
TestSQLContext.sparkContext.parallelize(
BinaryData("12".getBytes(), 1) ::
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
BinaryData("123".getBytes(), 4) :: Nil).toDF()
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
testData3.registerTempTable("testData3")
case class UpperCaseData(N: Int, L: String)
val upperCaseData =
TestSQLContext.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
UpperCaseData(6, "F") :: Nil).toDF()
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
val lowerCaseData =
TestSQLContext.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil).toDF()
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
val arrayData =
TestSQLContext.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
arrayData.toDF().registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])
val mapData =
TestSQLContext.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
mapData.toDF().registerTempTable("mapData")
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
repeatedData.toDF().registerTempTable("repeatedData")
val nullableRepeatedData =
TestSQLContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
case class NullInts(a: Integer)
val nullInts =
TestSQLContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
).toDF()
nullInts.registerTempTable("nullInts")
val allNulls =
TestSQLContext.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
NullInts(null) :: Nil).toDF()
allNulls.registerTempTable("allNulls")
case class NullStrings(n: Int, s: String)
val nullStrings =
TestSQLContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
nullStrings.registerTempTable("nullStrings")
case class TableName(tableName: String)
TestSQLContext
.sparkContext
.parallelize(TableName("test") :: Nil)
.toDF()
.registerTempTable("tableName")
val unparsedStrings =
TestSQLContext.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)
case class IntField(i: Int)
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
withEmptyParts.toDF().registerTempTable("withEmptyParts")
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
val person = TestSQLContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
person.registerTempTable("person")
val salary = TestSQLContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
val complexData =
TestSQLContext.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
:: Nil).toDF()
complexData.registerTempTable("complexData")
}
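
The fixtures above are all built through the `TestSQLContext` singleton, which is exactly what this patch removes. A minimal sketch of the shape they move toward, with the data defined against an injected `SQLContext` so each suite can supply its own context (the trait and case class names below are illustrative, not the actual `SQLTestData` contents):

import org.apache.spark.sql.{DataFrame, SQLContext}

// Hypothetical fixture; the real test data lives in org.apache.spark.sql.test.SQLTestData.
case class KeyValue(key: Int, value: String)

trait ExampleTestData {
  // Supplied by whatever suite mixes this in (e.g. a shared-context trait).
  protected def _sqlContext: SQLContext

  // Lazy so nothing touches the context until a test actually needs the table.
  protected lazy val keyValueData: DataFrame = {
    val df = _sqlContext.createDataFrame(KeyValue(1, "one") :: KeyValue(2, "two") :: Nil)
    df.registerTempTable("keyValueData")
    df
  }
}

Defining each table lazily also keeps suites that never use a given fixture from paying for it.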

View file

@ -17,16 +17,13 @@
package org.apache.spark.sql
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
case class FunctionResult(f1: String, f2: String)
private case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest with SQLTestUtils {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
override def sqlContext(): SQLContext = ctx
class UDFSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("built-in fixed arity expressions") {
val df = ctx.emptyDataFrame
@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("SPARK-8003 spark_partition_id") {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
ctx.dropTempTable("tmp_table")
}
@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
val answer = sql("select input_file_name() from test_table").head().getString(0)
assert(answer.contains(dir.getCanonicalPath))
assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
ctx.dropTempTable("test_table")
}
}
@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("Simple UDF") {
ctx.udf.register("strLenScala", (_: String).length)
assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
ctx.udf.register("random0", () => { Math.random()})
assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
ctx.udf.register("strLenScala", (_: String).length + (_: Int))
assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("UDF in a WHERE") {
@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("integerData")
val result =
ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
assert(result.count() === 20)
}
@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT g, SUM(v) as s
| FROM groupData
@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT SUM(v)
| FROM groupData
@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
ctx.sql(
sql(
"""
| SELECT timesHundred(SUM(v)) as v100
| FROM groupData
@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
ctx.sql("SELECT returnStruct('test', 'test2') as ret")
sql("SELECT returnStruct('test', 'test2') as ret")
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("udf that is transformed") {
ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
test("type coercion for udf inputs") {
ctx.udf.register("intExpected", (x: Int) => x)
// pass a decimal to intExpected.
assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
}
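
The recurring change in this suite is that queries now go through the suite's shared context (`sql(...)`, `ctx.udf`) instead of the removed singleton. For reference, the same UDF round trip outside the test harness looks roughly like this (a hedged sketch against the public API of this era; the master, app name, object name and UDF name are all arbitrary):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object UdfExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("udf-example"))
    val sqlContext = new SQLContext(sc)

    // Register a Scala function and call it from SQL.
    sqlContext.udf.register("strLenScala", (s: String) => s.length)
    val len = sqlContext.sql("SELECT strLenScala('test')").head().getInt(0)
    assert(len == 4)

    sc.stop()
  }
}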

View file

@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashSet
@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
private[spark] override def asNullable: MyDenseVectorUDT = this
}
class UserDefinedTypeSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val pointsRDD = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
ctx.sql("SELECT testType(features) from points"),
sql("SELECT testType(features) from points"),
Seq(Row(true), Row(true)))
}

View file

@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.{logicalPlanToSparkQuery, sql}
setupTestData()
test("simple columnar query") {
val plan = ctx.executePlan(testData.logicalPlan).executedPlan

View file

@ -17,20 +17,19 @@
package org.apache.spark.sql.columnar
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
}
before {
ctx.cacheTable("pruningData")
}
after {
ctx.uncacheTable("pruningData")
override protected def afterAll(): Unit = {
try {
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
ctx.uncacheTable("pruningData")
} finally {
super.afterAll()
}
}
// Comparisons
@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
val df = ctx.sql(query)
val df = sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {

View file

@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest {
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("shuffling UnsafeRows in exchange") {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans._
@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
class PlannerSuite extends SparkFunSuite with SQLTestUtils {
class PlannerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
override def sqlContext: SQLContext = TestSQLContext
setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val _ctx = ctx
import _ctx.planner._
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
plannedOption.getOrElse(
@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
}
test("unions are collapsed") {
val _ctx = ctx
import _ctx.planner._
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
val logicalUnions = query collect { case u: logical.Union => u }
@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
val fields = fieldTypes.zipWithIndex.map {
case (dataType, index) => StructField(s"c${index}", dataType, true)
} :+ StructField("key", IntegerType, true)
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
dropTempTable("testLimit")
ctx.dropTempTable("testLimit")
}
val origThreshold = conf.autoBroadcastJoinThreshold
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
val simpleTypes =
NullType ::
@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
checkPlan(complexTypes, newThreshold = 901617)
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("InMemoryRelation statistics propagation") {
val origThreshold = conf.autoBroadcastJoinThreshold
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
val a = testData.as("a")
val b = table("tiny").as("b")
val b = ctx.table("tiny").as("b")
val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("efficient limit -> project -> sort") {
val query = testData.sort('key).select('value).limit(2).logicalPlan
val planned = planner.TakeOrderedAndProject(query)
val planned = ctx.planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}
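
The assertions above all follow the same idiom: run a query, then collect the join operators of interest from `queryExecution.executedPlan`. A condensed, hedged sketch of that idiom on its own (object and column names are made up, it assumes the 1.5-era join operator classes are visible as they are to the suites in this patch, and which join gets planned depends on size estimates and `spark.sql.autoBroadcastJoinThreshold`):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin, SortMergeJoin}

object PlanInspectionExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("plan-inspection"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val large = sc.parallelize(1 to 1000).map(i => (i, s"v$i")).toDF("key", "value")
    val small = Seq((1, "a"), (2, "b"), (3, "c")).toDF("key", "word")

    // Collect whichever equi-join operator the planner picked for this query.
    val joined = large.join(small, "key")
    val strategies = joined.queryExecution.executedPlan.collect {
      case _: BroadcastHashJoin => "broadcast hash join"
      case _: ShuffledHashJoin  => "shuffled hash join"
      case _: SortMergeJoin     => "sort-merge join"
    }
    println(s"planned join strategies: ${strategies.mkString(", ")}")

    sc.stop()
  }
}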

View file

@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
import org.apache.spark.unsafe.types.UTF8String
class RowFormatConvertersSuite extends SparkPlanTest {
class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
case c: ConvertToUnsafe => c
@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
}
test("filter can process unsafe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
}
@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("union requires all of its input rows' formats to agree") {
val plan = Union(Seq(outputsSafe, outputsUnsafe))
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("union can process safe rows") {
val plan = Union(Seq(outputsSafe, outputsSafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(!preparedPlan.outputsUnsafeRows)
}
test("union can process unsafe rows") {
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("round trip with ConvertToUnsafe and ConvertToSafe") {
val input = Seq(("hello", 1), ("world", 2))
checkAnswer(
TestSQLContext.createDataFrame(input),
ctx.createDataFrame(input),
plan => ConvertToSafe(ConvertToUnsafe(plan)),
input.map(Row.fromTuple)
)
}
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
SparkPlan.currentContext.set(TestSQLContext)
SparkPlan.currentContext.set(ctx)
val schema = ArrayType(StringType)
val rows = (1 to 100).map { i =>
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))

View file

@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
class SortSuite extends SparkPlanTest {
class SortSuite extends SparkPlanTest with SharedSQLContext {
// This test was originally added as an example of how to use [[SparkPlanTest]];
// it's not designed to be a comprehensive test of ExternalSort.
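
For readers following the pointer in the comment above, a rough sketch of what such a test looks like under the new trait (the suite name is made up, and the `ExternalSort` and `checkAnswer` shapes are assumed from this era of the code base):

package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext

class ExampleSortSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("ExternalSort orders rows by the sort key") {
    val input = Seq(("world", 2), ("hello", 1), ("spark", 3))
    checkAnswer(
      input.toDF("a", "b"),
      // Build the physical operator under test directly on top of the input plan.
      (child: SparkPlan) => ExternalSort('a.asc :: Nil, global = true, child),
      input.sortBy(_._1).map(Row.fromTuple),
      sortAnswers = false)
  }
}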

View file

@ -17,29 +17,27 @@
package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
/**
* Base class for writing tests for individual physical operators. For an example of how this
* class's test helper methods can be used, see [[SortSuite]].
*/
class SparkPlanTest extends SparkFunSuite {
protected def sqlContext: SQLContext = TestSQLContext
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
protected def _sqlContext: SQLContext
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
sqlContext.implicits.localSeqToDataFrameHolder(data)
_sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@ -151,13 +149,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
sqlContext: SQLContext): Option[String] = {
_sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
executePlan(expectedOutputPlan, sqlContext)
executePlan(expectedOutputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -172,7 +170,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
executePlan(outputPlan, sqlContext)
executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -212,12 +210,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
sqlContext: SQLContext): Option[String] = {
_sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
executePlan(outputPlan, sqlContext)
executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@ -280,10 +278,10 @@ object SparkPlanTest {
}
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
val resolvedPlan = sqlContext.prepareForExecution.execute(
val resolvedPlan = _sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap

View file

@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* A test suite that generates randomized data to test the [[TungstenSort]] operator.
*/
class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
override def beforeAll(): Unit = {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
super.beforeAll()
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
}
override def afterAll(): Unit = {
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
try {
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
} finally {
super.afterAll()
}
}
test("sort followed by limit") {
@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
}
test("sorting updates peak execution memory") {
val sc = TestSQLContext.sparkContext
val sc = ctx.sparkContext
AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
) {
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
val inputData = Seq.fill(1000)(randomDataGenerator())
val inputDf = TestSQLContext.createDataFrame(
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
val inputDf = ctx.createDataFrame(
ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
assert(TungstenSort.supportsSchema(inputDf.schema))

View file

@ -26,7 +26,7 @@ import org.scalatest.Matchers
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
*
* Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
*/
class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
class UnsafeFixedWidthAggregationMapSuite
extends SparkFunSuite
with Matchers
with SharedSQLContext {
import UnsafeFixedWidthAggregationMap._
@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with an empty map") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with empty records") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()

View file

@ -23,15 +23,14 @@ import org.apache.spark._
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
*/
class UnsafeKVExternalSorterSuite extends SparkFunSuite {
class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
inputData: Seq[(InternalRow, InternalRow)],
pageSize: Long,
spill: Boolean): Unit = {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager

View file

@ -21,15 +21,12 @@ import org.apache.spark._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.memory.TaskMemoryManager
class TungstenAggregationIteratorSuite extends SparkFunSuite {
class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
test("memory acquired on construction") {
// set up environment
val ctx = TestSQLContext
val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
TaskContext.setTaskContext(taskContext)

View file

@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.rdd.RDD
import org.scalactic.Tolerance._
import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
override def sqlContext: SQLContext = ctx // used by SQLTestUtils
import ctx.sql
import ctx.implicits._
class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
import testImplicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
ctx.read.schema(schema).json(path)
.queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.paths === Array(path))
assert(relationWithSchema.schema === schema)
@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
}
test("JSONRelation equality test") {
val context = org.apache.spark.sql.test.TestSQLContext
val relation0 = new JSONRelation(
Some(empty),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
None, None)(context)
None, None)(ctx)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation0 !== relation1)
@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = ResolvedDataSource(
context,
ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
val d2 = ResolvedDataSource(
context,
ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
"abd")
ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
checkAnswer(
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
checkAnswer(sql(
"SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
})
}
}

View file

@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
trait TestJsonData {
protected def ctx: SQLContext
private[json] trait TestJsonData {
protected def _sqlContext: SQLContext
def primitiveFieldAndType: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@ -36,7 +35,7 @@ trait TestJsonData {
}""" :: Nil)
def primitiveFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@ -47,14 +46,14 @@ trait TestJsonData {
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
def jsonNullStruct: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
def complexFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@ -65,14 +64,14 @@ trait TestJsonData {
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
def arrayElementTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
def missingFields: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
@ -80,7 +79,7 @@ trait TestJsonData {
"""{"e":"str"}""" :: Nil)
def complexFieldAndType1: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@ -96,7 +95,7 @@ trait TestJsonData {
}""" :: Nil)
def complexFieldAndType2: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@ -150,7 +149,7 @@ trait TestJsonData {
}""" :: Nil)
def mapType1: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
@ -158,7 +157,7 @@ trait TestJsonData {
"""{"map": {"e": null}}""" :: Nil)
def mapType2: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@ -167,21 +166,21 @@ trait TestJsonData {
"""{"map": {"f": {"field1": null}}}""" :: Nil)
def nullsInArrays: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
def jsonArray: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
def corruptRecords: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@ -190,7 +189,7 @@ trait TestJsonData {
"""]""" :: Nil)
def emptyRecords: RDD[String] =
ctx.sparkContext.parallelize(
_sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@ -198,9 +197,8 @@ trait TestJsonData {
"""{"b": [{"c": {}}]}""" ::
"""]""" :: Nil)
lazy val singleRow: RDD[String] =
ctx.sparkContext.parallelize(
"""{"a":123}""" :: Nil)
def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
}
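
All of these fixtures are plain RDDs of raw JSON strings; a suite turns them into tables through `read.json`. Roughly, in isolation (a hedged sketch; the object, method and table names are made up):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

object JsonFixtureExample {
  // Turn a raw-JSON string fixture, like the RDDs above, into a queryable temp table.
  def loadJsonFixture(sqlContext: SQLContext, tableName: String, json: Seq[String]): DataFrame = {
    val rdd: RDD[String] = sqlContext.sparkContext.parallelize(json)
    val df = sqlContext.read.json(rdd) // the schema is inferred from the records
    df.registerTempTable(tableName)
    df
  }
}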

View file

@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
import org.apache.hadoop.fs.Path
import org.apache.parquet.avro.AvroParquetWriter
import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.parquet.test.avro._
import org.apache.spark.sql.test.SharedSQLContext
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
override val sqlContext: SQLContext = TestSQLContext
private def withWriter[T <: IndexedRecord]
(path: String, schema: Schema)
(f: AvroParquetWriter[T] => Unit) = {
(f: AvroParquetWriter[T] => Unit): Unit = {
val writer = new AvroParquetWriter[T](new Path(path), schema)
try f(writer) finally writer.close()
}
@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
}
test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
import sqlContext.implicits._
import testImplicits._
withTempPath { dir =>
val path = dir.getCanonicalPath

View file

@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.MessageType
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.QueryTest
abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
def readParquetSchema(path: String): MessageType = {
/**
* Helper class for testing Parquet compatibility.
*/
private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
protected def readParquetSchema(path: String): MessageType = {
readParquetSchema(path, { path => !path.getName.startsWith("_") })
}
def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
val fsPath = new Path(path)
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {

View file

@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.filter2.predicate.Operators._
import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
* 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
* data type is nullable.
*/
class ParquetFilterSuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
private def checkFilterPredicate(
df: DataFrame,
@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
}
test("SPARK-6554: don't push down predicates which reference partition columns") {
import sqlContext.implicits._
import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>

View file

@ -37,6 +37,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
/**
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
import testImplicits._
/**
* Writes `data` to a Parquet file, reads it back and check file contents.

View file

@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.unsafe.types.UTF8String
import PartitioningUtils._
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
import PartitioningUtils._
import testImplicits._
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

View file

@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.test.SharedSQLContext
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
override def sqlContext: SQLContext = TestSQLContext
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
private def readParquetProtobufFile(name: String): DataFrame = {
val url = Thread.currentThread().getContextClassLoader.getResource(name)

View file

@ -21,16 +21,15 @@ import java.io.File
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuite extends QueryTest with ParquetTest {
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.sql
class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
}
sqlContext.catalog.unregisterTable(Seq("tmp"))
ctx.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
}
sqlContext.catalog.unregisterTable(Seq("tmp"))
ctx.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
val df2 = sqlContext.read.parquet(file.getCanonicalPath)
val df2 = ctx.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
}
}
@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
// delete summary files, so if we don't merge part-files, one column will not be included.
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
// Disables the global SQL option for schema merging
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
assertResult(2) {
// Disables schema merging via data source option
sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
}
assertResult(3) {
// Enables schema merging via data source option
sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
}
}
}

View file

@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
val sqlContext = TestSQLContext
abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
/**
* Checks whether the reflected Parquet message type for product type `T` conforms to `messageType`.

View file

@ -22,9 +22,8 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
/**
* A helper trait that provides convenient facilities for Parquet testing.
@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* convenient to use tuples rather than special case classes when writing test cases/suites.
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
private[sql] trait ParquetTest extends SQLTestUtils {
protected def _sqlContext: SQLContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
* returns.
@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
_sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
}
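
As the trait's scaladoc notes, wrapping values in tuples keeps these helpers convenient, and `Tuple1(Option(x))` in particular forces the inferred column to be nullable. A hedged usage sketch built only from the helpers above (the suite name is made up; it assumes the `SharedSQLContext` trait introduced by this patch supplies `_sqlContext`):

package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class ExampleParquetRoundTripSuite extends QueryTest with ParquetTest with SharedSQLContext {

  test("round-trips a single nullable Int column") {
    // Tuple1(Option(i)) makes the single column nullable; Tuple1(i) would not.
    withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { df =>
      assert(df.schema.head.nullable)
      checkAnswer(df, (1 to 4).map(Row(_)))
    }
  }
}
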
/**
@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetDataFrame(data) { df =>
sqlContext.registerDataFrameAsTable(df, tableName)
_sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
_sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](

View file

@ -17,14 +17,12 @@
package org.apache.spark.sql.execution.datasources.parquet
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
override val sqlContext: SQLContext = TestSQLContext
private val parquetFilePath =
Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")

View file

@ -18,10 +18,10 @@
package org.apache.spark.sql.execution.debug
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.SharedSQLContext
class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
class DebuggingSuite extends SparkFunSuite {
test("DataFrame.debug()") {
testData.debug()
}
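
The `debug()` call comes from the implicits in the `org.apache.spark.sql.execution.debug` package object; it executes the plan and reports what each operator produced (tuple counts and column types). A hedged sketch of a slightly larger use, mirroring the shared-context pattern of this patch (the suite name is made up, and it is placed in the debug package so the implicits resolve without an extra import):

package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

class ExampleDebuggingSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("debug() on a derived DataFrame") {
    val df = ctx.sparkContext.parallelize(1 to 100).map(i => (i, i.toString)).toDF("key", "value")
    // Runs the filtered plan and prints what flowed through each operator.
    df.filter($"key" > 50).debug()
  }
}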

View file

@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite {
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
// Key is simply the record itself
private val keyProjection = new Projection {
@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy()).toArray

View file

@ -17,97 +17,19 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
private def testInnerJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}
def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using BroadcastHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeBroadcastHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using ShuffledHashJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using ShuffledHashJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeShuffledHashJoin(left, right, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using SortMergeJoin") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
makeSortMergeJoin(left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
{
val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val myUpperCaseData = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
Row(3, "C"),
@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))
val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
private lazy val myLowerCaseData = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c"),
@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "e")
)), new StructType().add("n", IntegerType).add("l", StringType))
testInnerJoin(
"inner join, one match per row",
upperCaseData,
lowerCaseData,
(upperCaseData.col("N") === lowerCaseData.col("n")).expr,
Seq(
(1, "A", 1, "a"),
(2, "B", 2, "b"),
(3, "C", 3, "c"),
(4, "D", 4, "d")
)
)
}
private val testData2 = Seq(
private lazy val myTestData = Seq(
(1, 1),
(1, 2),
(2, 1),
@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
(3, 2)
).toDF("a", "b")
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
private def testInnerJoin(
testName: String,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: () => Expression,
expectedAnswer: Seq[Product]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
ExtractEquiJoinKeys.unapply(join)
}
def makeBroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val broadcastHashJoin =
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
}
def makeShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
val shuffledHashJoin =
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
val filteredJoin =
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
def makeSortMergeJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
boundCondition: Option[Expression],
leftPlan: SparkPlan,
rightPlan: SparkPlan) = {
val sortMergeJoin =
execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
EnsureRequirements(sqlContext).apply(filteredJoin)
}
test(s"$testName using BroadcastHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastHashJoin (build=right)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using ShuffledHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using ShuffledHashJoin (build=right)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using SortMergeJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
testInnerJoin(
"inner join, one match per row",
myUpperCaseData,
myLowerCaseData,
() => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
Seq(
(1, "A", 1, "a"),
(2, "B", 2, "b"),
(3, "C", 3, "c"),
(4, "D", 4, "d")
)
)
{
val left = testData2.where("a = 1")
val right = testData2.where("a = 1")
lazy val left = myTestData.where("a = 1")
lazy val right = myTestData.where("a = 1")
testInnerJoin(
"inner join, multiple matches",
left,
right,
(left.col("a") === right.col("a")).expr,
() => (left.col("a") === right.col("a")).expr,
Seq(
(1, 1, 1, 1),
(1, 1, 1, 2),
@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
}
{
val left = testData2.where("a = 1")
val right = testData2.where("a = 2")
lazy val left = myTestData.where("a = 1")
lazy val right = myTestData.where("a = 2")
testInnerJoin(
"inner join, no matches",
left,
right,
(left.col("a") === right.col("a")).expr,
() => (left.col("a") === right.col("a")).expr,
Seq.empty
)
}
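
// Sketch (not part of this patch; helper name is hypothetical): the by-name
// parameters above mean the DataFrames and the join condition are only evaluated
// inside each test body, i.e. after SharedSQLContext has started the SparkContext.
// A distilled version of the same pattern:
private def testRowCount(testName: String, rows: => DataFrame, expected: Long): Unit = {
  test(testName) {
    assert(rows.count() === expected) // `rows` is first evaluated here, inside the test
  }
}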


@ -17,28 +17,65 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(2, 100.0),
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, 1.0),
Row(3, 3.0),
Row(5, 1.0),
Row(6, 6.0),
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
private lazy val right = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, -1.0),
Row(2, -1.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(5, 3.0),
Row(7, 7.0),
Row(null, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
private lazy val condition = {
And((left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
private def testOuterJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
leftRows: => DataFrame,
rightRows: => DataFrame,
joinType: JoinType,
condition: Expression,
condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using ShuffledHashOuterJoin") {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join)
}
test(s"$testName using ShuffledHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
@ -46,19 +83,23 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
if (joinType != FullOuter) {
test(s"$testName using BroadcastHashOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
}
test(s"$testName using SortMergeOuterJoin") {
test(s"$testName using SortMergeOuterJoin") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
@ -66,57 +107,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
}
}
}
test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(2, 100.0),
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, 1.0),
Row(3, 3.0),
Row(5, 1.0),
Row(6, 6.0),
Row(null, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(0, 0.0),
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
Row(2, -1.0),
Row(2, -1.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(5, 3.0),
Row(7, 7.0),
Row(null, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
val condition = {
And(
(left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
// --- Basic outer joins ------------------------------------------------------------------------


@ -17,44 +17,80 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
Row(2, 1.0),
Row(3, 3.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
private lazy val right = ctx.createDataFrame(
ctx.sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
private lazy val condition = {
And((left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
private def testLeftSemiJoin(
testName: String,
leftRows: DataFrame,
rightRows: DataFrame,
condition: Expression,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join).foreach {
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
test(s"$testName using LeftSemiJoinHash") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
test(s"$testName using BroadcastLeftSemiJoinHash") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
ExtractEquiJoinKeys.unapply(join)
}
test(s"$testName using LeftSemiJoinHash") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext).apply(
LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastLeftSemiJoinHash") {
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
}
}
test(s"$testName using LeftSemiJoinBNL") {
@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
}
}
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(1, 2.0),
Row(1, 2.0),
Row(2, 1.0),
Row(2, 1.0),
Row(3, 3.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("a", IntegerType).add("b", DoubleType))
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
Row(2, 3.0),
Row(2, 3.0),
Row(3, 2.0),
Row(4, 1.0),
Row(null, null),
Row(null, 5.0),
Row(6, null)
)), new StructType().add("c", IntegerType).add("d", DoubleType))
val condition = {
And(
(left.col("a") === right.col("c")).expr,
LessThan(left.col("b").expr, right.col("d").expr))
}
testLeftSemiJoin(
"basic test",
left,


@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
override val sqlContext = TestSQLContext
import sqlContext.implicits._
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
test("LongSQLMetric should not box Long") {
val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Normal accumulator should do boxing") {
// We need this test to make sure BoxingFinder works.
val l = TestSQLContext.sparkContext.accumulator(0L)
val l = ctx.sparkContext.accumulator(0L)
val f = () => { l += 1L }
BoxingFinder.getClassReader(f.getClass).foreach { cl =>
val boxingFinder = new BoxingFinder()
@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
df: DataFrame,
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
val previousExecutionIds = ctx.listener.executionIdToData.keySet
df.collect()
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition where we may miss some jobs
// TODO: change it to "=" once we fix the race condition where the JobStarted event can be missed.
assert(jobs.size <= expectedNumOfJobs)
if (jobs.size == expectedNumOfJobs) {
// If we can track all jobs, check the metric values
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
val metricValues = ctx.listener.getExecutionMetrics(executionId)
val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
expectedMetrics.contains(node.id)
}.map { node =>
@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Project(nodeId = 0)
val df = TestData.person.select('name)
val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("Project", Map(
"number of rows" -> 2L)))
@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "true") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = TestData.person.select('name)
val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("TungstenProject", Map(
"number of rows" -> 2L)))
@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Filter metrics") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
val df = TestData.person.filter('age < 25)
val df = person.filter('age < 25)
testSparkPlanMetrics(df, 1, Map(
0L -> ("Filter", Map(
"number of input rows" -> 2L,
@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) ->
// SortBasedAggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2)
// -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
3L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1)
// -> TungstenAggregate(nodeId = 0)
val df = TestData.testData2.groupBy().count() // 2 partitions
val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
val df2 = TestData.testData2.groupBy('a).count()
val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use a deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
// this test should use a deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("ShuffledHashJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("BroadcastNestedLoopJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
}
test("CartesianProduct metrics") {
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("save metrics") {
withTempPath { file =>
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
val previousExecutionIds = ctx.listener.executionIdToData.keySet
// Assume the execution plan is
// PhysicalRDD(nodeId = 0)
TestData.person.select('name).write.format("json").save(file.getAbsolutePath)
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
person.select('name).write.format("json").save(file.getAbsolutePath)
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition where we may miss some jobs
// TODO: change "<=" to "=" once we fix the race condition where the JobStarted event can be missed.
assert(jobs.size <= 1)
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
val metricValues = ctx.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
assert(metricValues.values.toSeq === Seq(2L))


@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
class SQLListenerSuite extends SparkFunSuite {
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
private def createTestDataFrame: DataFrame = {
import TestSQLContext.implicits._
Seq(
(1, 1),
(2, 2)
@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("basic") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
val accumulatorIds =
@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
val listener = new SQLListener(TestSQLContext)
val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(


@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
import testImplicits._
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
Some(StringType)
}
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Utils.classForName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test


@ -23,11 +23,13 @@ import java.util.Properties
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
val url1 = "jdbc:h2:mem:testdb3"
@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Utils.classForName("org.h2.Driver")
conn = DriverManager.getConnection(url)
@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
ctx.sql(
sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
ctx.sql(
sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}


@ -19,28 +19,32 @@ package org.apache.spark.sql.sources
import java.io.{File, IOException}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.DDLException
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
private var path: File = null
override def beforeAll(): Unit = {
super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
}
override def afterAll(): Unit = {
caseInsensitiveContext.dropTempTable("jt")
try {
caseInsensitiveContext.dropTempTable("jt")
} finally {
super.afterAll()
}
}
after {


@ -18,11 +18,12 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
// Note: META-INF/services in the test directory had to be modified for this test to work
class DDLSourceLoadSuite extends DataSourceTest {
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
test("data sources with the same name") {
intercept[RuntimeException] {


@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
}
}
class DDLTestSuite extends DataSourceTest {
class DDLTestSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
before {
caseInsensitiveContext.sql(
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE ddlPeople
|USING org.apache.spark.sql.sources.DDLScanSource
@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest {
))
test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
val attributes = caseInsensitiveContext.sql("describe ddlPeople")
val attributes = sql("describe ddlPeople")
.queryExecution.executedPlan.output
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
assert(attributes.map(_.dataType).toSet === Set(StringType))


@ -17,18 +17,23 @@
package org.apache.spark.sql.sources
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext
abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
private[sql] abstract class DataSourceTest extends QueryTest {
protected def _sqlContext: SQLContext
// We want to test some edge cases.
protected implicit lazy val caseInsensitiveContext = {
val ctx = new SQLContext(TestSQLContext.sparkContext)
protected lazy val caseInsensitiveContext: SQLContext = {
val ctx = new SQLContext(_sqlContext.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) {
test(sqlString) {
checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer)
}
}
}


@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@ -96,11 +97,11 @@ object FiltersPushed {
var list: Seq[Filter] = Nil
}
class FilteredScanSuite extends DataSourceTest {
class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
import caseInsensitiveContext.sql
before {
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTenFiltered


@ -19,20 +19,17 @@ package org.apache.spark.sql.sources
import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class InsertSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var path: File = null
private var path: File = null
override def beforeAll(): Unit = {
super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
caseInsensitiveContext.dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jt")
Utils.deleteRecursively(path)
try {
caseInsensitiveContext.dropTempTable("jsonTable")
caseInsensitiveContext.dropTempTable("jt")
Utils.deleteRecursively(path)
} finally {
super.afterAll()
}
}
test("Simple INSERT OVERWRITE a JSONRelation") {
@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a * 2 FROM jsonTable"),
(1 to 10).map(i => Row(i * 2)).toSeq)
assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
checkAnswer(
sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
assertCached(sql(
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
checkAnswer(sql(
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Insert overwrite and keep the same schema.


@ -19,21 +19,21 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
class PartitionedWriteSuite extends QueryTest {
import TestSQLContext.implicits._
class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("write many partitions") {
val path = Utils.createTempDir()
path.delete()
val df = TestSQLContext.range(100).select($"id", lit(1).as("data"))
val df = ctx.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
TestSQLContext.read.load(path.getCanonicalPath),
ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest {
val path = Utils.createTempDir()
path.delete()
val base = TestSQLContext.range(100)
val base = ctx.range(100)
val df = base.unionAll(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
TestSQLContext.read.load(path.getCanonicalPath),
ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)


@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class PrunedScanSource extends RelationProvider {
@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}
}
class PrunedScanSuite extends DataSourceTest {
class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
before {
caseInsensitiveContext.sql(
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest {
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
val queryExecution = sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: execution.PhysicalRDD => p
} match {


@ -19,25 +19,22 @@ package org.apache.spark.sql.sources
import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensitiveContext.sql
class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
var originalDefaultSource: String = null
var path: File = null
var df: DataFrame = null
private var originalDefaultSource: String = null
private var path: File = null
private var df: DataFrame = null
override def beforeAll(): Unit = {
super.beforeAll()
originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
path = Utils.createTempDir()
@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
try {
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
} finally {
super.afterAll()
}
}
after {
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
Utils.deleteRecursively(path)
}


@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class DefaultSource extends SimpleScanSource
@ -95,8 +96,8 @@ case class AllDataTypesScan(
}
}
class TableScanSuite extends DataSourceTest {
import caseInsensitiveContext.sql
class TableScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
Row(
@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest {
Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))))
}.toSeq
before {
override def beforeAll(): Unit = {
super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTen
@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest {
sql("SELECT i * 2 FROM oneToTen"),
(1 to 10).map(i => Row(i * 2)).toSeq)
assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
checkAnswer(
sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
assertCached(sql(
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
checkAnswer(sql(
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Verify uncaching


@ -0,0 +1,290 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.test
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
/**
* A collection of sample data used in SQL tests.
*/
private[sql] trait SQLTestData { self =>
protected def _sqlContext: SQLContext
// Helper object to import SQL implicits without a concrete SQLContext
private object internalImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self._sqlContext
}
import internalImplicits._
import SQLTestData._
// Note: all test data should be lazy because the SQLContext is not set up yet.
protected lazy val testData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
df.registerTempTable("testData")
df
}
protected lazy val testData2: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
TestData2(1, 1) ::
TestData2(1, 2) ::
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
TestData2(3, 2) :: Nil, 2).toDF()
df.registerTempTable("testData2")
df
}
protected lazy val testData3: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
TestData3(1, None) ::
TestData3(2, Some(2)) :: Nil).toDF()
df.registerTempTable("testData3")
df
}
protected lazy val negativeData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
df.registerTempTable("negativeData")
df
}
protected lazy val largeAndSmallInts: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
LargeAndSmallInts(2147483644, 1) ::
LargeAndSmallInts(1, 2) ::
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
LargeAndSmallInts(3, 2) :: Nil).toDF()
df.registerTempTable("largeAndSmallInts")
df
}
protected lazy val decimalData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
DecimalData(1, 1) ::
DecimalData(1, 2) ::
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
DecimalData(3, 2) :: Nil).toDF()
df.registerTempTable("decimalData")
df
}
protected lazy val binaryData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
BinaryData("12".getBytes, 1) ::
BinaryData("22".getBytes, 5) ::
BinaryData("122".getBytes, 3) ::
BinaryData("121".getBytes, 2) ::
BinaryData("123".getBytes, 4) :: Nil).toDF()
df.registerTempTable("binaryData")
df
}
protected lazy val upperCaseData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
UpperCaseData(1, "A") ::
UpperCaseData(2, "B") ::
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
UpperCaseData(6, "F") :: Nil).toDF()
df.registerTempTable("upperCaseData")
df
}
protected lazy val lowerCaseData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil).toDF()
df.registerTempTable("lowerCaseData")
df
}
protected lazy val arrayData: RDD[ArrayData] = {
val rdd = _sqlContext.sparkContext.parallelize(
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
rdd.toDF().registerTempTable("arrayData")
rdd
}
protected lazy val mapData: RDD[MapData] = {
val rdd = _sqlContext.sparkContext.parallelize(
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
rdd.toDF().registerTempTable("mapData")
rdd
}
protected lazy val repeatedData: RDD[StringData] = {
val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("repeatedData")
rdd
}
protected lazy val nullableRepeatedData: RDD[StringData] = {
val rdd = _sqlContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
rdd.toDF().registerTempTable("nullableRepeatedData")
rdd
}
protected lazy val nullInts: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullInts(1) ::
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil).toDF()
df.registerTempTable("nullInts")
df
}
protected lazy val allNulls: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
NullInts(null) :: Nil).toDF()
df.registerTempTable("allNulls")
df
}
protected lazy val nullStrings: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
NullStrings(3, null) :: Nil).toDF()
df.registerTempTable("nullStrings")
df
}
protected lazy val tableName: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
df.registerTempTable("tableName")
df
}
protected lazy val unparsedStrings: RDD[String] = {
_sqlContext.sparkContext.parallelize(
"1, A1, true, null" ::
"2, B2, false, null" ::
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)
}
// An RDD with 4 elements and 8 partitions
protected lazy val withEmptyParts: RDD[IntField] = {
val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
rdd.toDF().registerTempTable("withEmptyParts")
rdd
}
protected lazy val person: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
Person(1, "jim", 20) :: Nil).toDF()
df.registerTempTable("person")
df
}
protected lazy val salary: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
Salary(1, 1000.0) :: Nil).toDF()
df.registerTempTable("salary")
df
}
protected lazy val complexData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
Nil).toDF()
df.registerTempTable("complexData")
df
}
/**
* Initialize all test data such that all temp tables are properly registered.
*/
def loadTestData(): Unit = {
assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
testData
testData2
testData3
negativeData
largeAndSmallInts
decimalData
binaryData
upperCaseData
lowerCaseData
arrayData
mapData
repeatedData
nullableRepeatedData
nullInts
allNulls
nullStrings
tableName
unparsedStrings
withEmptyParts
person
salary
complexData
}
}
/**
* Case classes used in test data.
*/
private[sql] object SQLTestData {
case class TestData(key: Int, value: String)
case class TestData2(a: Int, b: Int)
case class TestData3(a: Int, b: Option[Int])
case class LargeAndSmallInts(a: Int, b: Int)
case class DecimalData(a: BigDecimal, b: BigDecimal)
case class BinaryData(a: Array[Byte], b: Int)
case class UpperCaseData(N: Int, L: String)
case class LowerCaseData(n: Int, l: String)
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
case class MapData(data: scala.collection.Map[Int, String])
case class StringData(s: String)
case class IntField(i: Int)
case class NullInts(a: Integer)
case class NullStrings(n: Int, s: String)
case class TableName(tableName: String)
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
}
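
// Sketch (not part of this patch; "ExtraTestData" and the "squares" table are
// hypothetical): suites needing extra tables can follow the same lazy-registration
// pattern, so nothing touches the SQLContext until the data is actually referenced.
private[sql] trait ExtraTestData extends SQLTestData { self =>
  private object extraImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self._sqlContext
  }
  import extraImplicits._
  import SQLTestData._

  protected lazy val squares: DataFrame = {
    val df = _sqlContext.sparkContext.parallelize(
      (1 to 5).map(i => TestData(i, (i * i).toString))).toDF()
    df.registerTempTable("squares")
    df
  }
}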


@ -21,15 +21,71 @@ import java.io.File
import java.util.UUID
import scala.util.Try
import scala.language.implicitConversions
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.util.Utils
trait SQLTestUtils { this: SparkFunSuite =>
protected def sqlContext: SQLContext
/**
* Helper trait that should be extended by all SQL test suites.
*
 * This allows subclasses to plug in a custom [[SQLContext]]. It comes with test data
 * prepared in advance as well as all implicit conversions used extensively by DataFrames.
 * To use the implicit methods, import `testImplicits._` instead of the ones on the [[SQLContext]].
*
* Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is
* prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
*/
private[sql] trait SQLTestUtils
extends SparkFunSuite
with BeforeAndAfterAll
with SQLTestData { self =>
protected def configuration = sqlContext.sparkContext.hadoopConfiguration
protected def _sqlContext: SQLContext
// Whether to materialize all test data before the first test is run
private var loadTestDataBeforeTests = false
// Shorthand for running a query using our SQLContext
protected lazy val sql = _sqlContext.sql _
/**
* A helper object for importing SQL implicits.
*
* Note that the alternative of importing `sqlContext.implicits._` is not possible here.
* This is because we create the [[SQLContext]] immediately before the first test is run,
* but the implicits import is needed in the constructor.
*/
protected object testImplicits extends SQLImplicits {
protected override def _sqlContext: SQLContext = self._sqlContext
}
/**
* Materialize the test data immediately after the [[SQLContext]] is set up.
* This is necessary if the data is accessed by name but not through direct reference.
*/
protected def setupTestData(): Unit = {
loadTestDataBeforeTests = true
}
protected override def beforeAll(): Unit = {
super.beforeAll()
if (loadTestDataBeforeTests) {
loadTestData()
}
}
/**
* The Hadoop configuration used by the active [[SQLContext]].
*/
protected def configuration: Configuration = {
_sqlContext.sparkContext.hadoopConfiguration
}
/**
 * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
(keys, values).zipped.foreach(sqlContext.conf.setConfString)
val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
(keys, values).zipped.foreach(_sqlContext.conf.setConfString)
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
case (key, None) => sqlContext.conf.unsetConf(key)
case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
case (key, None) => _sqlContext.conf.unsetConf(key)
}
}
}
@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
* Drops temporary table `tableName` after calling `f`.
*/
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
try f finally tableNames.foreach(sqlContext.dropTempTable)
try f finally tableNames.foreach(_sqlContext.dropTempTable)
}
/**
@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
try f finally {
tableNames.foreach { name =>
sqlContext.sql(s"DROP TABLE IF EXISTS $name")
_sqlContext.sql(s"DROP TABLE IF EXISTS $name")
}
}
}
@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
try {
sqlContext.sql(s"CREATE DATABASE $dbName")
_sqlContext.sql(s"CREATE DATABASE $dbName")
} catch { case cause: Throwable =>
fail("Failed to create temporary database", cause)
}
try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
}
/**
@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite =>
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
sqlContext.sql(s"USE $db")
try f finally sqlContext.sql(s"USE default")
_sqlContext.sql(s"USE $db")
try f finally _sqlContext.sql(s"USE default")
}
/**
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier
* way to construct a [[DataFrame]] directly from local data without relying on implicits.
*/
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
DataFrame(_sqlContext, plan)
}
}
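For context, a minimal sketch of how a suite is expected to consume these helpers, assuming the `SharedSQLContext` introduced below; the suite name, table name, and data are invented for illustration and are not part of this patch:

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite: shows the intended use of the helpers above.
class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("helpers restore state when the body finishes") {
    // withSQLConf puts the previous value back even if the body throws.
    withSQLConf("spark.sql.shuffle.partitions" -> "3") {
      Seq((1, "a"), (2, "b")).toDF("id", "name").registerTempTable("people")
      // withTempTable drops the temporary table afterwards.
      withTempTable("people") {
        checkAnswer(sql("SELECT name FROM people WHERE id = 1"), Row("a"))
      }
    }
  }
}
```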
View file
@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.test
import org.apache.spark.sql.SQLContext
/**
* Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
*/
private[sql] trait SharedSQLContext extends SQLTestUtils {
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*
* By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
* mode with the default test configurations.
*/
private var _ctx: TestSQLContext = null
/**
* The [[TestSQLContext]] to use for all tests in this suite.
*/
protected def ctx: TestSQLContext = _ctx
protected def sqlContext: TestSQLContext = _ctx
protected override def _sqlContext: SQLContext = _ctx
/**
* Initialize the [[TestSQLContext]].
*/
protected override def beforeAll(): Unit = {
if (_ctx == null) {
_ctx = new TestSQLContext
}
// Ensure we have initialized the context before calling parent code
super.beforeAll()
}
/**
* Stop the underlying [[org.apache.spark.SparkContext]], if any.
*/
protected override def afterAll(): Unit = {
try {
if (_ctx != null) {
_ctx.sparkContext.stop()
_ctx = null
}
} finally {
super.afterAll()
}
}
}
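The same create-in-`beforeAll` / stop-in-`afterAll` lifecycle also works for suites that cannot share the default context and need to construct their own `SparkContext` with a non-default master or configuration. A hedged sketch (suite name, master string, and app name are assumptions, not part of this patch):

```scala
package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql.test.TestSQLContext

// Hypothetical suite: manages its own context instead of mixing in
// SharedSQLContext, so it can pick its own master and configuration.
class CustomContextSuite extends SparkFunSuite with BeforeAndAfterAll {
  private var ctx: TestSQLContext = null

  override protected def beforeAll(): Unit = {
    val conf = new SparkConf()
      .setMaster("local-cluster[2,1,1024]") // assumed master; anything the suite needs
      .setAppName("custom-context-suite")
      .set("spark.sql.testkey", "true")     // mirrors the default test configuration
    ctx = new TestSQLContext(new SparkContext(conf))
    super.beforeAll()
  }

  override protected def afterAll(): Unit = {
    try {
      if (ctx != null) {
        ctx.sparkContext.stop()
        ctx = null
      }
    } finally {
      super.afterAll()
    }
  }

  test("queries run against the suite-local context") {
    assert(ctx.sql("SELECT 1").collect().toSeq === Seq(Row(1)))
  }
}
```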
View file
@ -17,40 +17,36 @@
package org.apache.spark.sql.test
import scala.language.implicitConversions
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{SQLConf, SQLContext}
/** A SQLContext that can be used for local testing. */
class LocalSQLContext
extends SQLContext(
new SparkContext("local[2]", "TestSQLContext", new SparkConf()
.set("spark.sql.testkey", "true")
// SPARK-8910
.set("spark.ui.enabled", "false"))) {
override protected[sql] def createSession(): SQLSession = {
new this.SQLSession()
/**
* A special [[SQLContext]] prepared for testing.
*/
private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
def this() {
this(new SparkContext("local[2]", "test-sql-context",
new SparkConf().set("spark.sql.testkey", "true")))
}
// Use fewer partitions to speed up testing
protected[sql] override def createSession(): SQLSession = new this.SQLSession()
/** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
/** Fewer partitions to speed up testing. */
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
}
}
/**
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to
* construct [[DataFrame]] directly out of local data without relying on implicits.
*/
protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
DataFrame(this, plan)
// Needed for Java tests
def loadTestData(): Unit = {
testData.loadTestData()
}
private object testData extends SQLTestData {
protected override def _sqlContext: SQLContext = self
}
}
object TestSQLContext extends LocalSQLContext
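Since the test session only overrides the *default* number of shuffle partitions, a suite can still raise it through the ordinary conf. A hypothetical check (suite and test names invented, not part of this patch):

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical check: the test session only changes the default shuffle
// partition count, so an explicitly set conf value still takes precedence.
class ShufflePartitionDefaultSuite extends QueryTest with SharedSQLContext {
  test("test sessions default to 5 shuffle partitions") {
    assert(ctx.conf.numShufflePartitions === 5)
    withSQLConf("spark.sql.shuffle.partitions" -> "10") {
      assert(ctx.conf.numShufflePartitions === 10)
    }
  }
}
```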
View file
@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ui.SparkUICssErrorHandler
class UISeleniumSuite
@ -36,7 +35,6 @@ class UISeleniumSuite
implicit var webDriver: WebDriver = _
var server: HiveThriftServer2 = _
var hc: HiveContext = _
val uiPort = 20000 + Random.nextInt(10000)
override def mode: ServerMode.Value = ServerMode.binary
View file
@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.sources.DataSourceTest
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.{Row, SaveMode, SQLContext}
import org.apache.spark.{Logging, SparkFunSuite}
@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
}
class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
import testImplicits._
private val testDF = range(1, 3).select(
('id + 0.1) cast DecimalType(10, 3) as 'd1,
View file
@ -19,14 +19,13 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row, SQLContext}
case class Cases(lower: String, UPPER: String)
class HiveParquetSuite extends QueryTest with ParquetTest {
val sqlContext = TestHive
import sqlContext._
private val ctx = TestHive
override def _sqlContext: SQLContext = ctx
test("Case insensitive attribute names") {
withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("Converting Hive to Parquet Table via saveAsParquetFile") {
withTempPath { dir =>
sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
read.parquet(dir.getCanonicalPath).registerTempTable("p")
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
checkAnswer(
sql("SELECT * FROM src ORDER BY key"),
@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
read.parquet(file.getCanonicalPath).registerTempTable("p")
ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
View file
@ -22,7 +22,6 @@ import java.io.{IOException, File}
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InvalidInputException
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.Logging
@ -42,7 +41,8 @@ import org.apache.spark.util.Utils
*/
class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
with Logging {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
var jsonFilePath: String = _
View file
@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
override val sqlContext: SQLContext = TestHive
import sqlContext.sql
override val _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
private val df = sqlContext.range(10).coalesce(1)
View file
@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext}
class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest {
import ParquetCompatibilityTest.makeNullable
override val sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
/**
* Set the staging directory (and hence path to ignore Parquet files under)
View file
@ -17,14 +17,12 @@
package org.apache.spark.sql.hive
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.QueryTest
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
import ctx.implicits._
test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })
View file
@ -17,17 +17,18 @@
package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql._
import org.scalatest.BeforeAndAfterAll
import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
var originalUseAggregate2: Boolean = _

View file

@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils
* A set of tests that validates support for Hive Explain command.
*/
class HiveExplainSuite extends QueryTest with SQLTestUtils {
def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
test("explain extended command") {
checkExistence(sql(" explain select * from src where key=123 "), true,
View file
@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect
* valid, but Hive currently cannot execute it.
*/
class SQLQuerySuite extends QueryTest with SQLTestUtils {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
test("UDTF") {
sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
View file
@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType
class ScriptTransformationSuite extends SparkPlanTest {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
private val noSerdeIOSchema = HiveScriptIOSchema(
inputRowFormat = Seq.empty,

View file
import org.apache.spark.sql.test.SQLTestUtils
private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
import sqlContext.sparkContext
View file
@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
* A collection of tests for parquet data with various forms of partitioning.
*/
abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
override def sqlContext: SQLContext = TestHive
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
var partitionedTableDir: File = null
var normalTableDir: File = null
View file
@ -18,14 +18,16 @@
package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
override val sqlContext = TestHive
override def _sqlContext: SQLContext = TestHive
private val sqlContext = _sqlContext
// When committing a task, `CommitFailureTestSource` throws an exception for testing purposes.
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
View file
@ -34,9 +34,8 @@ import org.apache.spark.sql.types._
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
override lazy val sqlContext: SQLContext = TestHive
import sqlContext.sql
override def _sqlContext: SQLContext = TestHive
protected val sqlContext = _sqlContext
import sqlContext.implicits._
val dataSourceName: String