[SPARK-9580] [SQL] Replace singletons in SQL tests
A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure. This patch removes the singletons `TestSQLContext` and `TestData`, and instead introduces a `SharedSQLContext` that starts a context per suite. Unfortunately the singletons were so ingrained in the SQL tests that this patch necessarily needed to touch *all* the SQL test files. <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/8111) <!-- Reviewable:end --> Author: Andrew Or <andrew@databricks.com> Closes #8111 from andrewor14/sql-tests-refactor.
This commit is contained in:
parent
c50f97dafd
commit
8187b3ae47
|
@ -178,6 +178,16 @@ object MimaExcludes {
|
|||
// SPARK-4751 Dynamic allocation for standalone mode
|
||||
ProblemFilters.exclude[MissingMethodProblem](
|
||||
"org.apache.spark.SparkContext.supportDynamicAllocation")
|
||||
) ++ Seq(
|
||||
// SPARK-9580: Remove SQL test singletons
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.sql.test.LocalSQLContext"),
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.sql.test.TestSQLContext"),
|
||||
ProblemFilters.exclude[MissingClassProblem](
|
||||
"org.apache.spark.sql.test.TestSQLContext$")
|
||||
) ++ Seq(
|
||||
// SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
|
||||
ProblemFilters.exclude[IncompatibleResultTypeProblem](
|
||||
|
|
|
@ -319,6 +319,8 @@ object SQL {
|
|||
lazy val settings = Seq(
|
||||
initialCommands in console :=
|
||||
"""
|
||||
|import org.apache.spark.SparkContext
|
||||
|import org.apache.spark.sql.SQLContext
|
||||
|import org.apache.spark.sql.catalyst.analysis._
|
||||
|import org.apache.spark.sql.catalyst.dsl._
|
||||
|import org.apache.spark.sql.catalyst.errors._
|
||||
|
@ -328,9 +330,14 @@ object SQL {
|
|||
|import org.apache.spark.sql.catalyst.util._
|
||||
|import org.apache.spark.sql.execution
|
||||
|import org.apache.spark.sql.functions._
|
||||
|import org.apache.spark.sql.test.TestSQLContext._
|
||||
|import org.apache.spark.sql.types._""".stripMargin,
|
||||
cleanupCommands in console := "sparkContext.stop()"
|
||||
|import org.apache.spark.sql.types._
|
||||
|
|
||||
|val sc = new SparkContext("local[*]", "dev-shell")
|
||||
|val sqlContext = new SQLContext(sc)
|
||||
|import sqlContext.implicits._
|
||||
|import sqlContext._
|
||||
""".stripMargin,
|
||||
cleanupCommands in console := "sc.stop()"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -340,8 +347,6 @@ object Hive {
|
|||
javaOptions += "-XX:MaxPermSize=256m",
|
||||
// Specially disable assertions since some Hive tests fail them
|
||||
javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
|
||||
// Multiple queries rely on the TestHive singleton. See comments there for more details.
|
||||
parallelExecution in Test := false,
|
||||
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
|
||||
// only for this subproject.
|
||||
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
|
||||
|
@ -349,6 +354,7 @@ object Hive {
|
|||
},
|
||||
initialCommands in console :=
|
||||
"""
|
||||
|import org.apache.spark.SparkContext
|
||||
|import org.apache.spark.sql.catalyst.analysis._
|
||||
|import org.apache.spark.sql.catalyst.dsl._
|
||||
|import org.apache.spark.sql.catalyst.errors._
|
||||
|
|
|
@ -17,14 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.analysis
|
||||
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.plans.Inner
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode {
|
|||
override def output: Seq[Attribute] = Nil
|
||||
}
|
||||
|
||||
class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
|
||||
class AnalysisErrorSuite extends AnalysisTest {
|
||||
import TestRelations._
|
||||
|
||||
def errorTest(
|
||||
|
|
|
@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference
|
|||
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.collection.immutable
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
|
@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
|||
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.execution.datasources._
|
||||
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
|
||||
import org.apache.spark.sql.sources.BaseRelation
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -334,98 +332,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
|
|||
* @since 1.3.0
|
||||
*/
|
||||
@Experimental
|
||||
object implicits extends Serializable {
|
||||
object implicits extends SQLImplicits with Serializable {
|
||||
protected override def _sqlContext: SQLContext = self
|
||||
}
|
||||
// scalastyle:on
|
||||
|
||||
/**
|
||||
* Converts $"col name" into an [[Column]].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit class StringToColumn(val sc: StringContext) {
|
||||
def $(args: Any*): ColumnName = {
|
||||
new ColumnName(sc.s(args: _*))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
|
||||
|
||||
/**
|
||||
* Creates a DataFrame from an RDD of case classes or tuples.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
|
||||
DataFrameHolder(self.createDataFrame(rdd))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a DataFrame from a local Seq of Product.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
|
||||
{
|
||||
DataFrameHolder(self.createDataFrame(data))
|
||||
}
|
||||
|
||||
// Do NOT add more implicit conversions. They are likely to break source compatibility by
|
||||
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
|
||||
// because of [[DoubleRDDFunctions]].
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[Int].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
|
||||
val dataType = IntegerType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.setInt(0, v)
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[Long].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
|
||||
val dataType = LongType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.setLong(0, v)
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[String].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
|
||||
val dataType = StringType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.update(0, UTF8String.fromString(v))
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* :: Experimental ::
|
||||
* Creates a DataFrame from an RDD of case classes.
|
||||
|
|
123
sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
Normal file
123
sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
Normal file
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
|
||||
import org.apache.spark.sql.types.StructField
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
/**
|
||||
* A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
|
||||
*/
|
||||
private[sql] abstract class SQLImplicits {
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
/**
|
||||
* Converts $"col name" into an [[Column]].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit class StringToColumn(val sc: StringContext) {
|
||||
def $(args: Any*): ColumnName = {
|
||||
new ColumnName(sc.s(args: _*))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An implicit conversion that turns a Scala `Symbol` into a [[Column]].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
|
||||
|
||||
/**
|
||||
* Creates a DataFrame from an RDD of case classes or tuples.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
|
||||
DataFrameHolder(_sqlContext.createDataFrame(rdd))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a DataFrame from a local Seq of Product.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
|
||||
{
|
||||
DataFrameHolder(_sqlContext.createDataFrame(data))
|
||||
}
|
||||
|
||||
// Do NOT add more implicit conversions. They are likely to break source compatibility by
|
||||
// making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
|
||||
// because of [[DoubleRDDFunctions]].
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[Int].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
|
||||
val dataType = IntegerType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.setInt(0, v)
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[Long].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
|
||||
val dataType = LongType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.setLong(0, v)
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a single column DataFrame from an RDD[String].
|
||||
* @since 1.3.0
|
||||
*/
|
||||
implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
|
||||
val dataType = StringType
|
||||
val rows = data.mapPartitions { iter =>
|
||||
val row = new SpecificMutableRow(dataType :: Nil)
|
||||
iter.map { v =>
|
||||
row.update(0, UTF8String.fromString(v))
|
||||
row: InternalRow
|
||||
}
|
||||
}
|
||||
DataFrameHolder(
|
||||
_sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
|
||||
}
|
||||
}
|
|
@ -27,6 +27,7 @@ import org.junit.Assert;
|
|||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.api.java.function.Function;
|
||||
|
@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame;
|
|||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.RowFactory;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
import org.apache.spark.sql.types.StructType;
|
||||
|
@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable {
|
|||
|
||||
@Before
|
||||
public void setUp() {
|
||||
sqlContext = TestSQLContext$.MODULE$;
|
||||
javaCtx = new JavaSparkContext(sqlContext.sparkContext());
|
||||
SparkContext context = new SparkContext("local[*]", "testing");
|
||||
javaCtx = new JavaSparkContext(context);
|
||||
sqlContext = new SQLContext(context);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
javaCtx = null;
|
||||
sqlContext.sparkContext().stop();
|
||||
sqlContext = null;
|
||||
javaCtx = null;
|
||||
}
|
||||
|
||||
public static class Person implements Serializable {
|
||||
|
|
|
@ -17,44 +17,45 @@
|
|||
|
||||
package test.org.apache.spark.sql;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.primitives.Ints;
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.*;
|
||||
import org.apache.spark.sql.test.TestSQLContext;
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.apache.spark.sql.types.*;
|
||||
import org.junit.*;
|
||||
|
||||
import scala.collection.JavaConversions;
|
||||
import scala.collection.Seq;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import scala.collection.JavaConversions;
|
||||
import scala.collection.Seq;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.primitives.Ints;
|
||||
import org.junit.*;
|
||||
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.*;
|
||||
import static org.apache.spark.sql.functions.*;
|
||||
import org.apache.spark.sql.test.TestSQLContext;
|
||||
import org.apache.spark.sql.types.*;
|
||||
|
||||
public class JavaDataFrameSuite {
|
||||
private transient JavaSparkContext jsc;
|
||||
private transient SQLContext context;
|
||||
private transient TestSQLContext context;
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
// Trigger static initializer of TestData
|
||||
TestData$.MODULE$.testData();
|
||||
jsc = new JavaSparkContext(TestSQLContext.sparkContext());
|
||||
context = TestSQLContext$.MODULE$;
|
||||
SparkContext sc = new SparkContext("local[*]", "testing");
|
||||
jsc = new JavaSparkContext(sc);
|
||||
context = new TestSQLContext(sc);
|
||||
context.loadTestData();
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
jsc = null;
|
||||
context.sparkContext().stop();
|
||||
context = null;
|
||||
jsc = null;
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -230,7 +231,7 @@ public class JavaDataFrameSuite {
|
|||
|
||||
@Test
|
||||
public void testSampleBy() {
|
||||
DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
|
||||
DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
|
||||
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
|
||||
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
|
||||
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
|
||||
|
|
|
@ -23,12 +23,12 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
import org.apache.spark.sql.api.java.UDF1;
|
||||
import org.apache.spark.sql.api.java.UDF2;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
|
||||
// The test suite itself is Serializable so that anonymous Function implementations can be
|
||||
|
@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable {
|
|||
|
||||
@Before
|
||||
public void setUp() {
|
||||
sqlContext = TestSQLContext$.MODULE$;
|
||||
sc = new JavaSparkContext(sqlContext.sparkContext());
|
||||
SparkContext _sc = new SparkContext("local[*]", "testing");
|
||||
sqlContext = new SQLContext(_sc);
|
||||
sc = new JavaSparkContext(_sc);
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
sqlContext.sparkContext().stop();
|
||||
sqlContext = null;
|
||||
sc = null;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
|
|
@ -21,13 +21,14 @@ import java.io.File;
|
|||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.sql.test.TestSQLContext$;
|
||||
import org.apache.spark.sql.*;
|
||||
import org.apache.spark.sql.types.DataTypes;
|
||||
import org.apache.spark.sql.types.StructField;
|
||||
|
@ -52,8 +53,9 @@ public class JavaSaveLoadSuite {
|
|||
|
||||
@Before
|
||||
public void setUp() throws IOException {
|
||||
sqlContext = TestSQLContext$.MODULE$;
|
||||
sc = new JavaSparkContext(sqlContext.sparkContext());
|
||||
SparkContext _sc = new SparkContext("local[*]", "testing");
|
||||
sqlContext = new SQLContext(_sc);
|
||||
sc = new JavaSparkContext(_sc);
|
||||
|
||||
originalDefaultSource = sqlContext.conf().defaultDataSourceName();
|
||||
path =
|
||||
|
@ -71,6 +73,13 @@ public class JavaSaveLoadSuite {
|
|||
df.registerTempTable("jsonTable");
|
||||
}
|
||||
|
||||
@After
|
||||
public void tearDown() {
|
||||
sqlContext.sparkContext().stop();
|
||||
sqlContext = null;
|
||||
sc = null;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveAndLoad() {
|
||||
Map<String, String> options = new HashMap<String, String>();
|
||||
|
|
|
@ -18,24 +18,20 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import scala.concurrent.duration._
|
||||
import scala.language.{implicitConversions, postfixOps}
|
||||
import scala.language.postfixOps
|
||||
|
||||
import org.scalatest.concurrent.Eventually._
|
||||
|
||||
import org.apache.spark.Accumulators
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.columnar._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
|
||||
|
||||
case class BigData(s: String)
|
||||
private case class BigData(s: String)
|
||||
|
||||
class CachedTableSuite extends QueryTest {
|
||||
TestData // Load test tables.
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import ctx.sql
|
||||
class CachedTableSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
def rddIdOf(tableName: String): Int = {
|
||||
val executedPlan = ctx.table(tableName).queryExecution.executedPlan
|
||||
|
|
|
@ -21,16 +21,20 @@ import org.scalatest.Matchers._
|
|||
|
||||
import org.apache.spark.sql.execution.{Project, TungstenProject}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
|
||||
class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
||||
import org.apache.spark.sql.TestData._
|
||||
class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
|
||||
override def sqlContext(): SQLContext = ctx
|
||||
private lazy val booleanData = {
|
||||
ctx.createDataFrame(ctx.sparkContext.parallelize(
|
||||
Row(false, false) ::
|
||||
Row(false, true) ::
|
||||
Row(true, false) ::
|
||||
Row(true, true) :: Nil),
|
||||
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
|
||||
}
|
||||
|
||||
test("column names with space") {
|
||||
val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
|
||||
|
@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("select isnull(null), isnull(1)"),
|
||||
sql("select isnull(null), isnull(1)"),
|
||||
Row(true, false))
|
||||
}
|
||||
|
||||
|
@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("select isnotnull(null), isnotnull('a')"),
|
||||
sql("select isnotnull(null), isnotnull('a')"),
|
||||
Row(false, true))
|
||||
}
|
||||
|
||||
|
@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("select isnan(15), isnan('invalid')"),
|
||||
sql("select isnan(15), isnan('invalid')"),
|
||||
Row(false, false))
|
||||
}
|
||||
|
||||
|
@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
)
|
||||
testData.registerTempTable("t")
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
|
||||
" nanvl(b, e), nanvl(e, f) from t"),
|
||||
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
|
||||
|
@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
}
|
||||
}
|
||||
|
||||
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
|
||||
Row(false, false) ::
|
||||
Row(false, true) ::
|
||||
Row(true, false) ::
|
||||
Row(true, true) :: Nil),
|
||||
StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
|
||||
|
||||
test("&&") {
|
||||
checkAnswer(
|
||||
booleanData.filter($"a" && true),
|
||||
|
@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT upper('aB'), ucase('cDe')"),
|
||||
sql("SELECT upper('aB'), ucase('cDe')"),
|
||||
Row("AB", "CDE"))
|
||||
}
|
||||
|
||||
|
@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT lower('aB'), lcase('cDe')"),
|
||||
sql("SELECT lower('aB'), lcase('cDe')"),
|
||||
Row("ab", "cde"))
|
||||
}
|
||||
|
||||
|
|
|
@ -17,15 +17,13 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{BinaryType, DecimalType}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.DecimalType
|
||||
|
||||
|
||||
class DataFrameAggregateSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("groupBy") {
|
||||
checkAnswer(
|
||||
|
|
|
@ -17,17 +17,15 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
* Test suite for functions in [[org.apache.spark.sql.functions]].
|
||||
*/
|
||||
class DataFrameFunctionsSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("array with column name") {
|
||||
val df = Seq((0, 1)).toDF("a", "b")
|
||||
|
@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest {
|
|||
|
||||
test("constant functions") {
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT E()"),
|
||||
sql("SELECT E()"),
|
||||
Row(scala.math.E)
|
||||
)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT PI()"),
|
||||
sql("SELECT PI()"),
|
||||
Row(scala.math.Pi)
|
||||
)
|
||||
}
|
||||
|
@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest {
|
|||
|
||||
test("nvl function") {
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
|
||||
sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
|
||||
Row("x", "y", null))
|
||||
}
|
||||
|
||||
|
@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest {
|
|||
Row(-1)
|
||||
)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
|
||||
sql("SELECT least(a, 2) as l from testData2 order by l"),
|
||||
Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
|
||||
)
|
||||
}
|
||||
|
@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest {
|
|||
Row(3)
|
||||
)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
|
||||
sql("SELECT greatest(a, 2) as g from testData2 order by g"),
|
||||
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
|
||||
)
|
||||
}
|
||||
|
|
|
@ -17,10 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
class DataFrameImplicitsSuite extends QueryTest {
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("RDD of tuples") {
|
||||
checkAnswer(
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class DataFrameJoinSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("join - join using") {
|
||||
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
|
||||
|
@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest {
|
|||
|
||||
checkAnswer(
|
||||
df1.join(df2, $"df1.key" === $"df2.key"),
|
||||
ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
|
||||
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
|
||||
.collect().toSeq)
|
||||
}
|
||||
|
||||
|
|
|
@ -19,11 +19,11 @@ package org.apache.spark.sql
|
|||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class DataFrameNaFunctionsSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
def createDF(): DataFrame = {
|
||||
Seq[(String, java.lang.Integer, java.lang.Double)](
|
||||
|
|
|
@ -19,20 +19,17 @@ package org.apache.spark.sql
|
|||
|
||||
import java.util.Random
|
||||
|
||||
import org.scalatest.Matchers._
|
||||
|
||||
import org.apache.spark.sql.functions.col
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class DataFrameStatSuite extends QueryTest {
|
||||
|
||||
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlCtx.implicits._
|
||||
class DataFrameStatSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private def toLetter(i: Int): String = (i + 97).toChar.toString
|
||||
|
||||
test("sample with replacement") {
|
||||
val n = 100
|
||||
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
checkAnswer(
|
||||
data.sample(withReplacement = true, 0.05, seed = 13),
|
||||
Seq(5, 10, 52, 73).map(Row(_))
|
||||
|
@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest {
|
|||
|
||||
test("sample without replacement") {
|
||||
val n = 100
|
||||
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
checkAnswer(
|
||||
data.sample(withReplacement = false, 0.05, seed = 13),
|
||||
Seq(16, 23, 88, 100).map(Row(_))
|
||||
|
@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest {
|
|||
|
||||
test("randomSplit") {
|
||||
val n = 600
|
||||
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
|
||||
for (seed <- 1 to 5) {
|
||||
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
|
||||
assert(splits.length == 3, "wrong number of splits")
|
||||
|
@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest {
|
|||
}
|
||||
|
||||
test("Frequent Items 2") {
|
||||
val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
|
||||
val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
|
||||
// this is a regression test, where when merging partitions, we omitted values with higher
|
||||
// counts than those that existed in the map when the map was full. This test should also fail
|
||||
// if anything like SPARK-9614 is observed once again
|
||||
|
@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest {
|
|||
}
|
||||
|
||||
test("sampleBy") {
|
||||
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
|
||||
val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
|
||||
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
|
||||
checkAnswer(
|
||||
sampled.groupBy("key").count().orderBy("key"),
|
||||
|
|
|
@ -23,18 +23,12 @@ import scala.language.postfixOps
|
|||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.execution.datasources.json.JSONRelation
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
|
||||
import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
|
||||
|
||||
class DataFrameSuite extends QueryTest with SQLTestUtils {
|
||||
import org.apache.spark.sql.TestData._
|
||||
|
||||
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.implicits._
|
||||
class DataFrameSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("analysis error should be eagerly reported") {
|
||||
// Eager analysis.
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
|
@ -27,10 +27,8 @@ import org.apache.spark.sql.types._
|
|||
* This is here for now so I can make sure Tungsten project is tested without refactoring existing
|
||||
* end-to-end test infra. In the long run this should just go away.
|
||||
*/
|
||||
class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
|
||||
|
||||
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.implicits._
|
||||
class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("test simple types") {
|
||||
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
|
||||
|
|
|
@ -22,19 +22,18 @@ import java.text.SimpleDateFormat
|
|||
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
|
||||
class DateFunctionsSuite extends QueryTest {
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
|
||||
import ctx.implicits._
|
||||
class DateFunctionsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("function current_date") {
|
||||
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
|
||||
val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
||||
val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
|
||||
val d2 = DateTimeUtils.fromJavaDate(
|
||||
ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
|
||||
sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
|
||||
val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
|
||||
assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
|
||||
}
|
||||
|
@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest {
|
|||
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
|
||||
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
|
||||
// Execution in one query should return the same value
|
||||
checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
|
||||
checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
|
||||
Row(true))
|
||||
assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
|
||||
assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
|
||||
0).getTime - System.currentTimeMillis()) < 5000)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,22 +17,15 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.execution.joins._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
|
||||
class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
||||
// Ensures tables are loaded.
|
||||
TestData
|
||||
class JoinSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
|
||||
lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import ctx.logicalPlanToSparkQuery
|
||||
setupTestData()
|
||||
|
||||
test("equi-join is hash-join") {
|
||||
val x = testData2.as("x")
|
||||
|
@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
}
|
||||
|
||||
def assertJoin(sqlString: String, c: Class[_]): Any = {
|
||||
val df = ctx.sql(sqlString)
|
||||
val df = sql(sqlString)
|
||||
val physical = df.queryExecution.sparkPlan
|
||||
val operators = physical.collect {
|
||||
case j: ShuffledHashJoin => j
|
||||
|
@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
|
||||
test("broadcasted hash join operator selection") {
|
||||
ctx.cacheManager.clearCache()
|
||||
ctx.sql("CACHE TABLE testData")
|
||||
sql("CACHE TABLE testData")
|
||||
for (sortMergeJoinEnabled <- Seq(true, false)) {
|
||||
withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
|
||||
|
@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
}
|
||||
}
|
||||
}
|
||||
ctx.sql("UNCACHE TABLE testData")
|
||||
sql("UNCACHE TABLE testData")
|
||||
}
|
||||
|
||||
test("broadcasted hash outer join operator selection") {
|
||||
ctx.cacheManager.clearCache()
|
||||
ctx.sql("CACHE TABLE testData")
|
||||
sql("CACHE TABLE testData")
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
|
||||
Seq(
|
||||
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
|
||||
|
@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
classOf[BroadcastHashOuterJoin])
|
||||
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
|
||||
}
|
||||
ctx.sql("UNCACHE TABLE testData")
|
||||
sql("UNCACHE TABLE testData")
|
||||
}
|
||||
|
||||
test("multiple-key equi-join is hash-join") {
|
||||
|
@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
// Make sure we are choosing left.outputPartitioning as the
|
||||
// outputPartitioning for the outer join operator.
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT l.N, count(*)
|
||||
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|
||||
|
@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
Row(6, 1) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT r.a, count(*)
|
||||
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
|
||||
|
@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
// Make sure we are choosing right.outputPartitioning as the
|
||||
// outputPartitioning for the outer join operator.
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT l.a, count(*)
|
||||
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|
||||
|
@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
Row(null, 6))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT r.N, count(*)
|
||||
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
|
||||
|
@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
|
||||
// Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT l.a, count(*)
|
||||
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|
||||
|
@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
Row(null, 10))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT r.N, count(*)
|
||||
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
|
||||
|
@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
Row(null, 4) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT l.N, count(*)
|
||||
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|
||||
|
@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
Row(null, 4) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
|SELECT r.a, count(*)
|
||||
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
|
||||
|
@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
|
||||
test("broadcasted left semi join operator selection") {
|
||||
ctx.cacheManager.clearCache()
|
||||
ctx.sql("CACHE TABLE testData")
|
||||
sql("CACHE TABLE testData")
|
||||
|
||||
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
|
||||
Seq(
|
||||
|
@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
|
|||
}
|
||||
}
|
||||
|
||||
ctx.sql("UNCACHE TABLE testData")
|
||||
sql("UNCACHE TABLE testData")
|
||||
}
|
||||
|
||||
test("left semi join") {
|
||||
val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
|
||||
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
|
||||
checkAnswer(df,
|
||||
Row(1, 1) ::
|
||||
Row(1, 2) ::
|
||||
|
|
|
@ -17,10 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
class JsonFunctionsSuite extends QueryTest {
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("function get_json_object") {
|
||||
val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")
|
||||
|
|
|
@ -19,12 +19,11 @@ package org.apache.spark.sql
|
|||
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
|
||||
|
||||
class ListTablesSuite extends QueryTest with BeforeAndAfter {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
|
||||
|
||||
|
@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
|
|||
Row("ListTablesSuiteTable", true))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
|
||||
sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
|
||||
Row("ListTablesSuiteTable", true))
|
||||
|
||||
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
|
||||
|
@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
|
|||
Row("ListTablesSuiteTable", true))
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
|
||||
sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
|
||||
Row("ListTablesSuiteTable", true))
|
||||
|
||||
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
|
||||
|
@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
|
|||
StructField("tableName", StringType, false) ::
|
||||
StructField("isTemporary", BooleanType, false) :: Nil)
|
||||
|
||||
Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
|
||||
Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
|
||||
case tableDF =>
|
||||
assert(expectedSchema === tableDF.schema)
|
||||
|
||||
tableDF.registerTempTable("tables")
|
||||
checkAnswer(
|
||||
ctx.sql(
|
||||
sql(
|
||||
"SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
|
||||
Row(true, "ListTablesSuiteTable")
|
||||
)
|
||||
|
|
|
@ -19,18 +19,16 @@ package org.apache.spark.sql
|
|||
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.functions.{log => logarithm}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
private object MathExpressionsTestData {
|
||||
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
|
||||
case class NullDoubles(a: java.lang.Double)
|
||||
}
|
||||
|
||||
class MathExpressionsSuite extends QueryTest {
|
||||
|
||||
class MathExpressionsSuite extends QueryTest with SharedSQLContext {
|
||||
import MathExpressionsTestData._
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import testImplicits._
|
||||
|
||||
private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
|
||||
|
||||
|
@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
test("toDegrees") {
|
||||
testOneToOneMathFunction(toDegrees, math.toDegrees)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
|
||||
sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
|
||||
Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
|
||||
)
|
||||
}
|
||||
|
@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
test("toRadians") {
|
||||
testOneToOneMathFunction(toRadians, math.toRadians)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
|
||||
sql("SELECT radians(0), radians(1), radians(1.5)"),
|
||||
Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
|
||||
)
|
||||
}
|
||||
|
@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
test("ceil and ceiling") {
|
||||
testOneToOneMathFunction(ceil, math.ceil)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
|
||||
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
|
||||
Row(0.0, 1.0, 2.0))
|
||||
}
|
||||
|
||||
|
@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
|
||||
val pi = 3.1415
|
||||
checkAnswer(
|
||||
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
|
||||
sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
|
||||
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
|
||||
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
|
||||
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
|
||||
|
@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
testOneToOneMathFunction[Double](signum, math.signum)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT sign(10), signum(-11)"),
|
||||
sql("SELECT sign(10), signum(-11)"),
|
||||
Row(1, -1))
|
||||
}
|
||||
|
||||
|
@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
testTwoToOneMathFunction(pow, pow, math.pow)
|
||||
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT pow(1, 2), power(2, 1)"),
|
||||
sql("SELECT pow(1, 2), power(2, 1)"),
|
||||
Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
|
||||
)
|
||||
}
|
||||
|
@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
test("log / ln") {
|
||||
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
|
||||
sql("SELECT ln(0), ln(1), ln(1.5)"),
|
||||
Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
|
||||
)
|
||||
}
|
||||
|
@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest {
|
|||
df.select(log2("b") + log2("a")),
|
||||
Row(1))
|
||||
|
||||
checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
|
||||
checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
|
||||
}
|
||||
|
||||
test("sqrt") {
|
||||
|
@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest {
|
|||
df.select(sqrt("a"), sqrt("b")),
|
||||
Row(1.0, 2.0))
|
||||
|
||||
checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
|
||||
checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
|
||||
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
|
||||
}
|
||||
|
||||
test("negative") {
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
|
||||
sql("SELECT negative(1), negative(0), negative(-1)"),
|
||||
Row(-1, 0, 1))
|
||||
}
|
||||
|
||||
|
|
|
@ -71,12 +71,6 @@ class QueryTest extends PlanTest {
|
|||
checkAnswer(df, expectedAnswer.collect())
|
||||
}
|
||||
|
||||
def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
|
||||
test(sqlString) {
|
||||
checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
|
||||
*/
|
||||
|
|
|
@ -20,13 +20,12 @@ package org.apache.spark.sql
|
|||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class RowSuite extends SparkFunSuite {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class RowSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("create row") {
|
||||
val expected = new GenericMutableRow(4)
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class SQLConfSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
|
||||
class SQLConfSuite extends QueryTest with SharedSQLContext {
|
||||
private val testKey = "test.key.0"
|
||||
private val testVal = "test.val.0"
|
||||
|
||||
|
@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest {
|
|||
|
||||
test("parse SQL set commands") {
|
||||
ctx.conf.clear()
|
||||
ctx.sql(s"set $testKey=$testVal")
|
||||
sql(s"set $testKey=$testVal")
|
||||
assert(ctx.getConf(testKey, testVal + "_") === testVal)
|
||||
assert(ctx.getConf(testKey, testVal + "_") === testVal)
|
||||
|
||||
ctx.sql("set some.property=20")
|
||||
sql("set some.property=20")
|
||||
assert(ctx.getConf("some.property", "0") === "20")
|
||||
ctx.sql("set some.property = 40")
|
||||
sql("set some.property = 40")
|
||||
assert(ctx.getConf("some.property", "0") === "40")
|
||||
|
||||
val key = "spark.sql.key"
|
||||
val vs = "val0,val_1,val2.3,my_table"
|
||||
ctx.sql(s"set $key=$vs")
|
||||
sql(s"set $key=$vs")
|
||||
assert(ctx.getConf(key, "0") === vs)
|
||||
|
||||
ctx.sql(s"set $key=")
|
||||
sql(s"set $key=")
|
||||
assert(ctx.getConf(key, "0") === "")
|
||||
|
||||
ctx.conf.clear()
|
||||
|
@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest {
|
|||
|
||||
test("deprecated property") {
|
||||
ctx.conf.clear()
|
||||
ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
|
||||
sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
|
||||
assert(ctx.conf.numShufflePartitions === 10)
|
||||
}
|
||||
|
||||
test("invalid conf value") {
|
||||
ctx.conf.clear()
|
||||
val e = intercept[IllegalArgumentException] {
|
||||
ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
|
||||
sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
|
||||
}
|
||||
assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
|
||||
}
|
||||
|
|
|
@ -17,16 +17,17 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
try {
|
||||
SQLContext.setLastInstantiatedContext(ctx)
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
test("getOrCreate instantiates SQLContext") {
|
||||
|
|
|
@ -19,28 +19,23 @@ package org.apache.spark.sql
|
|||
|
||||
import java.sql.Timestamp
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.AccumulatorSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
||||
import org.apache.spark.sql.catalyst.DefaultParserDialect
|
||||
import org.apache.spark.sql.catalyst.errors.DialectException
|
||||
import org.apache.spark.sql.execution.aggregate
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.test.SQLTestData._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/** A SQL Dialect for testing purpose, and it can not be nested type */
|
||||
class MyDialect extends DefaultParserDialect
|
||||
|
||||
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
||||
// Make sure the tables are loaded.
|
||||
TestData
|
||||
class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
val sqlContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.implicits._
|
||||
import sqlContext.sql
|
||||
setupTestData()
|
||||
|
||||
test("having clause") {
|
||||
Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
|
||||
|
@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("show functions") {
|
||||
checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
|
||||
checkAnswer(sql("SHOW functions"),
|
||||
FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
|
||||
}
|
||||
|
||||
test("describe functions") {
|
||||
|
@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
|
||||
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
|
||||
// we except the id is materialized once
|
||||
val idUDF = udf(() => UUID.randomUUID().toString)
|
||||
val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString)
|
||||
|
||||
val dfWithId = df.withColumn("id", idUDF())
|
||||
// Make a new DataFrame (actually the same reference to the old one)
|
||||
|
@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
|
||||
checkAnswer(
|
||||
sql(
|
||||
"""
|
||||
|SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
|
||||
""".stripMargin),
|
||||
"SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
|
||||
Row(2, 1, 2, 2, 1))
|
||||
}
|
||||
|
||||
|
@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
validateMetadata(sql("SELECT * FROM personWithMeta"))
|
||||
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
|
||||
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
|
||||
validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
|
||||
validateMetadata(sql(
|
||||
"SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
|
||||
}
|
||||
|
||||
test("SPARK-3371 Renaming a function expression with group by gives error") {
|
||||
|
@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
|
|||
.toDF("num", "str")
|
||||
df.registerTempTable("1one")
|
||||
|
||||
checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10))
|
||||
checkAnswer(sql("select count(num) from 1one"), Row(10))
|
||||
|
||||
sqlContext.dropTempTable("1one")
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql
|
|||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
case class ReflectData(
|
||||
stringField: String,
|
||||
|
@ -71,17 +72,15 @@ case class ComplexReflectData(
|
|||
mapFieldContainsNull: Map[Int, Option[Long]],
|
||||
dataField: Data)
|
||||
|
||||
class ScalaReflectionRelationSuite extends SparkFunSuite {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("query case class RDD") {
|
||||
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
|
||||
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
|
||||
Seq(data).toDF().registerTempTable("reflectData")
|
||||
|
||||
assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
|
||||
assert(sql("SELECT * FROM reflectData").collect().head ===
|
||||
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
|
||||
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
|
||||
new Timestamp(12345), Seq(1, 2, 3)))
|
||||
|
@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
|
|||
val data = NullReflectData(null, null, null, null, null, null, null)
|
||||
Seq(data).toDF().registerTempTable("reflectNullData")
|
||||
|
||||
assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
|
||||
assert(sql("SELECT * FROM reflectNullData").collect().head ===
|
||||
Row.fromSeq(Seq.fill(7)(null)))
|
||||
}
|
||||
|
||||
|
@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
|
|||
val data = OptionalReflectData(None, None, None, None, None, None, None)
|
||||
Seq(data).toDF().registerTempTable("reflectOptionalData")
|
||||
|
||||
assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
|
||||
assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
|
||||
Row.fromSeq(Seq.fill(7)(null)))
|
||||
}
|
||||
|
||||
|
@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
|
|||
test("query binary data") {
|
||||
Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
|
||||
|
||||
val result = ctx.sql("SELECT data FROM reflectBinary")
|
||||
val result = sql("SELECT data FROM reflectBinary")
|
||||
.collect().head(0).asInstanceOf[Array[Byte]]
|
||||
assert(result.toSeq === Seq[Byte](1))
|
||||
}
|
||||
|
@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
|
|||
Nested(None, "abc")))
|
||||
|
||||
Seq(data).toDF().registerTempTable("reflectComplexData")
|
||||
assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
|
||||
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
|
||||
Row(
|
||||
Seq(1, 2, 3),
|
||||
Seq(1, 2, null),
|
||||
|
|
|
@ -19,13 +19,12 @@ package org.apache.spark.sql
|
|||
|
||||
import org.apache.spark.{SparkConf, SparkFunSuite}
|
||||
import org.apache.spark.serializer.JavaSerializer
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class SerializationSuite extends SparkFunSuite {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
class SerializationSuite extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
test("[SPARK-5235] SQLContext should be serializable") {
|
||||
val sqlContext = new SQLContext(ctx.sparkContext)
|
||||
new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
|
||||
val _sqlContext = new SQLContext(sqlContext.sparkContext)
|
||||
new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,13 +18,12 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
|
||||
|
||||
class StringFunctionsSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("string concat") {
|
||||
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
|
||||
|
|
|
@ -1,197 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.test.TestSQLContext.implicits._
|
||||
import org.apache.spark.sql.test._
|
||||
|
||||
|
||||
case class TestData(key: Int, value: String)
|
||||
|
||||
object TestData {
|
||||
val testData = TestSQLContext.sparkContext.parallelize(
|
||||
(1 to 100).map(i => TestData(i, i.toString))).toDF()
|
||||
testData.registerTempTable("testData")
|
||||
|
||||
val negativeData = TestSQLContext.sparkContext.parallelize(
|
||||
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
|
||||
negativeData.registerTempTable("negativeData")
|
||||
|
||||
case class LargeAndSmallInts(a: Int, b: Int)
|
||||
val largeAndSmallInts =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
LargeAndSmallInts(2147483644, 1) ::
|
||||
LargeAndSmallInts(1, 2) ::
|
||||
LargeAndSmallInts(2147483645, 1) ::
|
||||
LargeAndSmallInts(2, 2) ::
|
||||
LargeAndSmallInts(2147483646, 1) ::
|
||||
LargeAndSmallInts(3, 2) :: Nil).toDF()
|
||||
largeAndSmallInts.registerTempTable("largeAndSmallInts")
|
||||
|
||||
case class TestData2(a: Int, b: Int)
|
||||
val testData2 =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
TestData2(1, 1) ::
|
||||
TestData2(1, 2) ::
|
||||
TestData2(2, 1) ::
|
||||
TestData2(2, 2) ::
|
||||
TestData2(3, 1) ::
|
||||
TestData2(3, 2) :: Nil, 2).toDF()
|
||||
testData2.registerTempTable("testData2")
|
||||
|
||||
case class DecimalData(a: BigDecimal, b: BigDecimal)
|
||||
|
||||
val decimalData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
DecimalData(1, 1) ::
|
||||
DecimalData(1, 2) ::
|
||||
DecimalData(2, 1) ::
|
||||
DecimalData(2, 2) ::
|
||||
DecimalData(3, 1) ::
|
||||
DecimalData(3, 2) :: Nil).toDF()
|
||||
decimalData.registerTempTable("decimalData")
|
||||
|
||||
case class BinaryData(a: Array[Byte], b: Int)
|
||||
val binaryData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
BinaryData("12".getBytes(), 1) ::
|
||||
BinaryData("22".getBytes(), 5) ::
|
||||
BinaryData("122".getBytes(), 3) ::
|
||||
BinaryData("121".getBytes(), 2) ::
|
||||
BinaryData("123".getBytes(), 4) :: Nil).toDF()
|
||||
binaryData.registerTempTable("binaryData")
|
||||
|
||||
case class TestData3(a: Int, b: Option[Int])
|
||||
val testData3 =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
TestData3(1, None) ::
|
||||
TestData3(2, Some(2)) :: Nil).toDF()
|
||||
testData3.registerTempTable("testData3")
|
||||
|
||||
case class UpperCaseData(N: Int, L: String)
|
||||
val upperCaseData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
UpperCaseData(1, "A") ::
|
||||
UpperCaseData(2, "B") ::
|
||||
UpperCaseData(3, "C") ::
|
||||
UpperCaseData(4, "D") ::
|
||||
UpperCaseData(5, "E") ::
|
||||
UpperCaseData(6, "F") :: Nil).toDF()
|
||||
upperCaseData.registerTempTable("upperCaseData")
|
||||
|
||||
case class LowerCaseData(n: Int, l: String)
|
||||
val lowerCaseData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
LowerCaseData(1, "a") ::
|
||||
LowerCaseData(2, "b") ::
|
||||
LowerCaseData(3, "c") ::
|
||||
LowerCaseData(4, "d") :: Nil).toDF()
|
||||
lowerCaseData.registerTempTable("lowerCaseData")
|
||||
|
||||
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
|
||||
val arrayData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
|
||||
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
|
||||
arrayData.toDF().registerTempTable("arrayData")
|
||||
|
||||
case class MapData(data: scala.collection.Map[Int, String])
|
||||
val mapData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
|
||||
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
|
||||
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
|
||||
MapData(Map(1 -> "a4", 2 -> "b4")) ::
|
||||
MapData(Map(1 -> "a5")) :: Nil)
|
||||
mapData.toDF().registerTempTable("mapData")
|
||||
|
||||
case class StringData(s: String)
|
||||
val repeatedData =
|
||||
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
|
||||
repeatedData.toDF().registerTempTable("repeatedData")
|
||||
|
||||
val nullableRepeatedData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
List.fill(2)(StringData(null)) ++
|
||||
List.fill(2)(StringData("test")))
|
||||
nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
|
||||
|
||||
case class NullInts(a: Integer)
|
||||
val nullInts =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
NullInts(1) ::
|
||||
NullInts(2) ::
|
||||
NullInts(3) ::
|
||||
NullInts(null) :: Nil
|
||||
).toDF()
|
||||
nullInts.registerTempTable("nullInts")
|
||||
|
||||
val allNulls =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
NullInts(null) ::
|
||||
NullInts(null) ::
|
||||
NullInts(null) ::
|
||||
NullInts(null) :: Nil).toDF()
|
||||
allNulls.registerTempTable("allNulls")
|
||||
|
||||
case class NullStrings(n: Int, s: String)
|
||||
val nullStrings =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
NullStrings(1, "abc") ::
|
||||
NullStrings(2, "ABC") ::
|
||||
NullStrings(3, null) :: Nil).toDF()
|
||||
nullStrings.registerTempTable("nullStrings")
|
||||
|
||||
case class TableName(tableName: String)
|
||||
TestSQLContext
|
||||
.sparkContext
|
||||
.parallelize(TableName("test") :: Nil)
|
||||
.toDF()
|
||||
.registerTempTable("tableName")
|
||||
|
||||
val unparsedStrings =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
"1, A1, true, null" ::
|
||||
"2, B2, false, null" ::
|
||||
"3, C3, true, null" ::
|
||||
"4, D4, true, 2147483644" :: Nil)
|
||||
|
||||
case class IntField(i: Int)
|
||||
// An RDD with 4 elements and 8 partitions
|
||||
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
|
||||
withEmptyParts.toDF().registerTempTable("withEmptyParts")
|
||||
|
||||
case class Person(id: Int, name: String, age: Int)
|
||||
case class Salary(personId: Int, salary: Double)
|
||||
val person = TestSQLContext.sparkContext.parallelize(
|
||||
Person(0, "mike", 30) ::
|
||||
Person(1, "jim", 20) :: Nil).toDF()
|
||||
person.registerTempTable("person")
|
||||
val salary = TestSQLContext.sparkContext.parallelize(
|
||||
Salary(0, 2000.0) ::
|
||||
Salary(1, 1000.0) :: Nil).toDF()
|
||||
salary.registerTempTable("salary")
|
||||
|
||||
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
|
||||
val complexData =
|
||||
TestSQLContext.sparkContext.parallelize(
|
||||
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
|
||||
:: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
|
||||
:: Nil).toDF()
|
||||
complexData.registerTempTable("complexData")
|
||||
}
|
|
@ -17,16 +17,13 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.test.SQLTestData._
|
||||
|
||||
case class FunctionResult(f1: String, f2: String)
|
||||
private case class FunctionResult(f1: String, f2: String)
|
||||
|
||||
class UDFSuite extends QueryTest with SQLTestUtils {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
|
||||
override def sqlContext(): SQLContext = ctx
|
||||
class UDFSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("built-in fixed arity expressions") {
|
||||
val df = ctx.emptyDataFrame
|
||||
|
@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
test("SPARK-8003 spark_partition_id") {
|
||||
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
|
||||
df.registerTempTable("tmp_table")
|
||||
checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
|
||||
checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
|
||||
ctx.dropTempTable("tmp_table")
|
||||
}
|
||||
|
||||
|
@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
|
||||
data.write.parquet(dir.getCanonicalPath)
|
||||
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
|
||||
val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
|
||||
val answer = sql("select input_file_name() from test_table").head().getString(0)
|
||||
assert(answer.contains(dir.getCanonicalPath))
|
||||
assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
|
||||
assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
|
||||
ctx.dropTempTable("test_table")
|
||||
}
|
||||
}
|
||||
|
@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
|
||||
test("Simple UDF") {
|
||||
ctx.udf.register("strLenScala", (_: String).length)
|
||||
assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
|
||||
assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
|
||||
}
|
||||
|
||||
test("ZeroArgument UDF") {
|
||||
ctx.udf.register("random0", () => { Math.random()})
|
||||
assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
|
||||
assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
|
||||
}
|
||||
|
||||
test("TwoArgument UDF") {
|
||||
ctx.udf.register("strLenScala", (_: String).length + (_: Int))
|
||||
assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
|
||||
assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
|
||||
}
|
||||
|
||||
test("UDF in a WHERE") {
|
||||
|
@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
df.registerTempTable("integerData")
|
||||
|
||||
val result =
|
||||
ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
|
||||
sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
|
||||
assert(result.count() === 20)
|
||||
}
|
||||
|
||||
|
@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
df.registerTempTable("groupData")
|
||||
|
||||
val result =
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
| SELECT g, SUM(v) as s
|
||||
| FROM groupData
|
||||
|
@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
df.registerTempTable("groupData")
|
||||
|
||||
val result =
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
| SELECT SUM(v)
|
||||
| FROM groupData
|
||||
|
@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
df.registerTempTable("groupData")
|
||||
|
||||
val result =
|
||||
ctx.sql(
|
||||
sql(
|
||||
"""
|
||||
| SELECT timesHundred(SUM(v)) as v100
|
||||
| FROM groupData
|
||||
|
@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
|
||||
|
||||
val result =
|
||||
ctx.sql("SELECT returnStruct('test', 'test2') as ret")
|
||||
sql("SELECT returnStruct('test', 'test2') as ret")
|
||||
.select($"ret.f1").head().getString(0)
|
||||
assert(result === "test")
|
||||
}
|
||||
|
@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
|
|||
test("udf that is transformed") {
|
||||
ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
|
||||
// 1 + 1 is constant folded causing a transformation.
|
||||
assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
|
||||
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
|
||||
}
|
||||
|
||||
test("type coercion for udf inputs") {
|
||||
ctx.udf.register("intExpected", (x: Int) => x)
|
||||
// pass a decimal to intExpected.
|
||||
assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
|
||||
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.util.collection.OpenHashSet
|
||||
|
@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
|
|||
private[spark] override def asNullable: MyDenseVectorUDT = this
|
||||
}
|
||||
|
||||
class UserDefinedTypeSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private lazy val pointsRDD = Seq(
|
||||
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
|
||||
|
@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
|
|||
ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
|
||||
pointsRDD.registerTempTable("points")
|
||||
checkAnswer(
|
||||
ctx.sql("SELECT testType(features) from points"),
|
||||
sql("SELECT testType(features) from points"),
|
||||
Seq(Row(true), Row(true)))
|
||||
}
|
||||
|
||||
|
|
|
@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
|
|||
|
||||
import java.sql.{Date, Timestamp}
|
||||
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.{QueryTest, Row}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.test.SQLTestData._
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{QueryTest, Row, TestData}
|
||||
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
|
||||
|
||||
class InMemoryColumnarQuerySuite extends QueryTest {
|
||||
// Make sure the tables are loaded.
|
||||
TestData
|
||||
class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import ctx.{logicalPlanToSparkQuery, sql}
|
||||
setupTestData()
|
||||
|
||||
test("simple columnar query") {
|
||||
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
|
||||
|
|
|
@ -17,20 +17,19 @@
|
|||
|
||||
package org.apache.spark.sql.columnar
|
||||
|
||||
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.test.SQLTestData._
|
||||
|
||||
class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
|
||||
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
|
||||
|
||||
override protected def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
|
||||
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
|
||||
|
||||
|
@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
|
|||
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
|
||||
// Enable in-memory table scan accumulators
|
||||
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
|
||||
}
|
||||
|
||||
override protected def afterAll(): Unit = {
|
||||
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
|
||||
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
|
||||
}
|
||||
|
||||
before {
|
||||
ctx.cacheTable("pruningData")
|
||||
}
|
||||
|
||||
after {
|
||||
override protected def afterAll(): Unit = {
|
||||
try {
|
||||
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
|
||||
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
|
||||
ctx.uncacheTable("pruningData")
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
// Comparisons
|
||||
|
@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
|
|||
expectedQueryResult: => Seq[Int]): Unit = {
|
||||
|
||||
test(query) {
|
||||
val df = ctx.sql(query)
|
||||
val df = sql(query)
|
||||
val queryExecution = df.queryExecution
|
||||
|
||||
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
|
||||
|
|
|
@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class ExchangeSuite extends SparkPlanTest {
|
||||
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
|
||||
test("shuffling UnsafeRows in exchange") {
|
||||
val input = (1 to 1000).map(Tuple1.apply)
|
||||
checkAnswer(
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.{execution, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
|
@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
|||
import org.apache.spark.sql.catalyst.plans.physical._
|
||||
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.sql.test.TestSQLContext.implicits._
|
||||
import org.apache.spark.sql.test.TestSQLContext.planner._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
|
||||
|
||||
|
||||
class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
||||
class PlannerSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
override def sqlContext: SQLContext = TestSQLContext
|
||||
setupTestData()
|
||||
|
||||
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
|
||||
val _ctx = ctx
|
||||
import _ctx.planner._
|
||||
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
|
||||
val planned =
|
||||
plannedOption.getOrElse(
|
||||
|
@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("unions are collapsed") {
|
||||
val _ctx = ctx
|
||||
import _ctx.planner._
|
||||
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
|
||||
val planned = BasicOperators(query).head
|
||||
val logicalUnions = query collect { case u: logical.Union => u }
|
||||
|
@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
|
||||
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
|
||||
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
|
||||
val fields = fieldTypes.zipWithIndex.map {
|
||||
case (dataType, index) => StructField(s"c${index}", dataType, true)
|
||||
} :+ StructField("key", IntegerType, true)
|
||||
val schema = StructType(fields)
|
||||
val row = Row.fromSeq(Seq.fill(fields.size)(null))
|
||||
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
|
||||
createDataFrame(rowRDD, schema).registerTempTable("testLimit")
|
||||
val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
|
||||
ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
|
||||
|
||||
val planned = sql(
|
||||
"""
|
||||
|
@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
|||
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
|
||||
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
|
||||
|
||||
dropTempTable("testLimit")
|
||||
ctx.dropTempTable("testLimit")
|
||||
}
|
||||
|
||||
val origThreshold = conf.autoBroadcastJoinThreshold
|
||||
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
|
||||
|
||||
val simpleTypes =
|
||||
NullType ::
|
||||
|
@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
checkPlan(complexTypes, newThreshold = 901617)
|
||||
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
|
||||
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
|
||||
}
|
||||
|
||||
test("InMemoryRelation statistics propagation") {
|
||||
val origThreshold = conf.autoBroadcastJoinThreshold
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
|
||||
val origThreshold = ctx.conf.autoBroadcastJoinThreshold
|
||||
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
|
||||
|
||||
testData.limit(3).registerTempTable("tiny")
|
||||
sql("CACHE TABLE tiny")
|
||||
|
||||
val a = testData.as("a")
|
||||
val b = table("tiny").as("b")
|
||||
val b = ctx.table("tiny").as("b")
|
||||
val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
|
||||
|
||||
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
|
||||
|
@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
|
|||
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
|
||||
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
|
||||
|
||||
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
|
||||
ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
|
||||
}
|
||||
|
||||
test("efficient limit -> project -> sort") {
|
||||
val query = testData.sort('key).select('value).limit(2).logicalPlan
|
||||
val planned = planner.TakeOrderedAndProject(query)
|
||||
val planned = ctx.planner.TakeOrderedAndProject(query)
|
||||
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
|
||||
}
|
||||
|
||||
|
|
|
@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
|
|||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class RowFormatConvertersSuite extends SparkPlanTest {
|
||||
class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
|
||||
case c: ConvertToUnsafe => c
|
||||
|
@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
|
|||
|
||||
test("planner should insert unsafe->safe conversions when required") {
|
||||
val plan = Limit(10, outputsUnsafe)
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
|
||||
}
|
||||
|
||||
test("filter can process unsafe rows") {
|
||||
val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(getConverters(preparedPlan).size === 1)
|
||||
assert(preparedPlan.outputsUnsafeRows)
|
||||
}
|
||||
|
||||
test("filter can process safe rows") {
|
||||
val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(getConverters(preparedPlan).isEmpty)
|
||||
assert(!preparedPlan.outputsUnsafeRows)
|
||||
}
|
||||
|
@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
|
|||
test("union requires all of its input rows' formats to agree") {
|
||||
val plan = Union(Seq(outputsSafe, outputsUnsafe))
|
||||
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(preparedPlan.outputsUnsafeRows)
|
||||
}
|
||||
|
||||
test("union can process safe rows") {
|
||||
val plan = Union(Seq(outputsSafe, outputsSafe))
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(!preparedPlan.outputsUnsafeRows)
|
||||
}
|
||||
|
||||
test("union can process unsafe rows") {
|
||||
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
|
||||
val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
|
||||
val preparedPlan = ctx.prepareForExecution.execute(plan)
|
||||
assert(preparedPlan.outputsUnsafeRows)
|
||||
}
|
||||
|
||||
test("round trip with ConvertToUnsafe and ConvertToSafe") {
|
||||
val input = Seq(("hello", 1), ("world", 2))
|
||||
checkAnswer(
|
||||
TestSQLContext.createDataFrame(input),
|
||||
ctx.createDataFrame(input),
|
||||
plan => ConvertToSafe(ConvertToUnsafe(plan)),
|
||||
input.map(Row.fromTuple)
|
||||
)
|
||||
}
|
||||
|
||||
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
|
||||
SparkPlan.currentContext.set(TestSQLContext)
|
||||
SparkPlan.currentContext.set(ctx)
|
||||
val schema = ArrayType(StringType)
|
||||
val rows = (1 to 100).map { i =>
|
||||
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
|
||||
|
|
|
@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class SortSuite extends SparkPlanTest {
|
||||
class SortSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
// This test was originally added as an example of how to use [[SparkPlanTest]];
|
||||
// it's not designed to be a comprehensive test of ExternalSort.
|
||||
|
|
|
@ -17,29 +17,27 @@
|
|||
|
||||
package org.apache.spark.sql.execution
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.util._
|
||||
|
||||
/**
|
||||
* Base class for writing tests for individual physical operators. For an example of how this
|
||||
* class's test helper methods can be used, see [[SortSuite]].
|
||||
*/
|
||||
class SparkPlanTest extends SparkFunSuite {
|
||||
|
||||
protected def sqlContext: SQLContext = TestSQLContext
|
||||
private[sql] abstract class SparkPlanTest extends SparkFunSuite {
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
/**
|
||||
* Creates a DataFrame from a local Seq of Product.
|
||||
*/
|
||||
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
|
||||
sqlContext.implicits.localSeqToDataFrameHolder(data)
|
||||
_sqlContext.implicits.localSeqToDataFrameHolder(data)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
|
|||
planFunction: Seq[SparkPlan] => SparkPlan,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
|
||||
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
|
||||
case Some(errorMessage) => fail(errorMessage)
|
||||
case None =>
|
||||
}
|
||||
|
@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
|
|||
expectedPlanFunction: SparkPlan => SparkPlan,
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
SparkPlanTest.checkAnswer(
|
||||
input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
|
||||
input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
|
||||
case Some(errorMessage) => fail(errorMessage)
|
||||
case None =>
|
||||
}
|
||||
|
@ -151,13 +149,13 @@ object SparkPlanTest {
|
|||
planFunction: SparkPlan => SparkPlan,
|
||||
expectedPlanFunction: SparkPlan => SparkPlan,
|
||||
sortAnswers: Boolean,
|
||||
sqlContext: SQLContext): Option[String] = {
|
||||
_sqlContext: SQLContext): Option[String] = {
|
||||
|
||||
val outputPlan = planFunction(input.queryExecution.sparkPlan)
|
||||
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
|
||||
|
||||
val expectedAnswer: Seq[Row] = try {
|
||||
executePlan(expectedOutputPlan, sqlContext)
|
||||
executePlan(expectedOutputPlan, _sqlContext)
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
val errorMessage =
|
||||
|
@ -172,7 +170,7 @@ object SparkPlanTest {
|
|||
}
|
||||
|
||||
val actualAnswer: Seq[Row] = try {
|
||||
executePlan(outputPlan, sqlContext)
|
||||
executePlan(outputPlan, _sqlContext)
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
val errorMessage =
|
||||
|
@ -212,12 +210,12 @@ object SparkPlanTest {
|
|||
planFunction: Seq[SparkPlan] => SparkPlan,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean,
|
||||
sqlContext: SQLContext): Option[String] = {
|
||||
_sqlContext: SQLContext): Option[String] = {
|
||||
|
||||
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
|
||||
|
||||
val sparkAnswer: Seq[Row] = try {
|
||||
executePlan(outputPlan, sqlContext)
|
||||
executePlan(outputPlan, _sqlContext)
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
val errorMessage =
|
||||
|
@ -280,10 +278,10 @@ object SparkPlanTest {
|
|||
}
|
||||
}
|
||||
|
||||
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
|
||||
private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
|
||||
// A very simple resolver to make writing tests easier. In contrast to the real resolver
|
||||
// this is always case sensitive and does not try to handle scoping or complex type resolution.
|
||||
val resolvedPlan = sqlContext.prepareForExecution.execute(
|
||||
val resolvedPlan = _sqlContext.prepareForExecution.execute(
|
||||
outputPlan transform {
|
||||
case plan: SparkPlan =>
|
||||
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
|
||||
|
|
|
@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
|
|||
|
||||
import scala.util.Random
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.AccumulatorSuite
|
||||
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
* A test suite that generates randomized data to test the [[TungstenSort]] operator.
|
||||
*/
|
||||
class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
|
||||
class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
|
||||
super.beforeAll()
|
||||
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
|
||||
try {
|
||||
ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
test("sort followed by limit") {
|
||||
|
@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
|
|||
}
|
||||
|
||||
test("sorting updates peak execution memory") {
|
||||
val sc = TestSQLContext.sparkContext
|
||||
val sc = ctx.sparkContext
|
||||
AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
|
||||
checkThatPlansAgree(
|
||||
(1 to 100).map(v => Tuple1(v)).toDF("a"),
|
||||
|
@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
|
|||
) {
|
||||
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
|
||||
val inputData = Seq.fill(1000)(randomDataGenerator())
|
||||
val inputDf = TestSQLContext.createDataFrame(
|
||||
TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
|
||||
val inputDf = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
|
||||
StructType(StructField("a", dataType, nullable = true) :: Nil)
|
||||
)
|
||||
assert(TungstenSort.supportsSchema(inputDf.schema))
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.scalatest.Matchers
|
|||
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
|
||||
import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
|
|||
*
|
||||
* Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
|
||||
*/
|
||||
class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
|
||||
class UnsafeFixedWidthAggregationMapSuite
|
||||
extends SparkFunSuite
|
||||
with Matchers
|
||||
with SharedSQLContext {
|
||||
|
||||
import UnsafeFixedWidthAggregationMap._
|
||||
|
||||
|
@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
|
|||
}
|
||||
|
||||
testWithMemoryLeakDetection("test external sorting") {
|
||||
// Calling this make sure we have block manager and everything else setup.
|
||||
TestSQLContext
|
||||
|
||||
// Memory consumption in the beginning of the task.
|
||||
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
|
||||
|
||||
|
@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
|
|||
}
|
||||
|
||||
testWithMemoryLeakDetection("test external sorting with an empty map") {
|
||||
// Calling this make sure we have block manager and everything else setup.
|
||||
TestSQLContext
|
||||
|
||||
val map = new UnsafeFixedWidthAggregationMap(
|
||||
emptyAggregationBuffer,
|
||||
|
@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
|
|||
}
|
||||
|
||||
testWithMemoryLeakDetection("test external sorting with empty records") {
|
||||
// Calling this make sure we have block manager and everything else setup.
|
||||
TestSQLContext
|
||||
|
||||
// Memory consumption in the beginning of the task.
|
||||
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
|
||||
|
|
|
@ -23,15 +23,14 @@ import org.apache.spark._
|
|||
import org.apache.spark.sql.{RandomDataGenerator, Row}
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
|
||||
|
||||
/**
|
||||
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
|
||||
*/
|
||||
class UnsafeKVExternalSorterSuite extends SparkFunSuite {
|
||||
|
||||
class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
|
||||
private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
|
||||
private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
|
||||
|
||||
|
@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
|
|||
inputData: Seq[(InternalRow, InternalRow)],
|
||||
pageSize: Long,
|
||||
spill: Boolean): Unit = {
|
||||
// Calling this make sure we have block manager and everything else setup.
|
||||
TestSQLContext
|
||||
|
||||
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
|
||||
val shuffleMemMgr = new TestShuffleMemoryManager
|
||||
|
|
|
@ -21,15 +21,12 @@ import org.apache.spark._
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.unsafe.memory.TaskMemoryManager
|
||||
|
||||
class TungstenAggregationIteratorSuite extends SparkFunSuite {
|
||||
class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
test("memory acquired on construction") {
|
||||
// set up environment
|
||||
val ctx = TestSQLContext
|
||||
|
||||
val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
|
||||
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
|
||||
TaskContext.setTaskContext(taskContext)
|
||||
|
|
|
@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.scalactic.Tolerance._
|
||||
|
||||
import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
|
||||
import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
|
||||
|
||||
protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
override def sqlContext: SQLContext = ctx // used by SQLTestUtils
|
||||
|
||||
import ctx.sql
|
||||
import ctx.implicits._
|
||||
class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
|
||||
import testImplicits._
|
||||
|
||||
test("Type promotion") {
|
||||
def checkTypePromotion(expected: Any, actual: Any) {
|
||||
|
@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
|
|||
|
||||
val schema = StructType(StructField("a", LongType, true) :: Nil)
|
||||
val logicalRelation =
|
||||
ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
|
||||
ctx.read.schema(schema).json(path)
|
||||
.queryExecution.analyzed.asInstanceOf[LogicalRelation]
|
||||
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
|
||||
assert(relationWithSchema.paths === Array(path))
|
||||
assert(relationWithSchema.schema === schema)
|
||||
|
@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
|
|||
}
|
||||
|
||||
test("JSONRelation equality test") {
|
||||
val context = org.apache.spark.sql.test.TestSQLContext
|
||||
|
||||
val relation0 = new JSONRelation(
|
||||
Some(empty),
|
||||
1.0,
|
||||
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
|
||||
None, None)(context)
|
||||
None, None)(ctx)
|
||||
val logicalRelation0 = LogicalRelation(relation0)
|
||||
val relation1 = new JSONRelation(
|
||||
Some(singleRow),
|
||||
1.0,
|
||||
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
|
||||
None, None)(context)
|
||||
None, None)(ctx)
|
||||
val logicalRelation1 = LogicalRelation(relation1)
|
||||
val relation2 = new JSONRelation(
|
||||
Some(singleRow),
|
||||
0.5,
|
||||
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
|
||||
None, None)(context)
|
||||
None, None)(ctx)
|
||||
val logicalRelation2 = LogicalRelation(relation2)
|
||||
val relation3 = new JSONRelation(
|
||||
Some(singleRow),
|
||||
1.0,
|
||||
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
|
||||
None, None)(context)
|
||||
None, None)(ctx)
|
||||
val logicalRelation3 = LogicalRelation(relation3)
|
||||
|
||||
assert(relation0 !== relation1)
|
||||
|
@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
|
|||
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
|
||||
|
||||
val d1 = ResolvedDataSource(
|
||||
context,
|
||||
ctx,
|
||||
userSpecifiedSchema = None,
|
||||
partitionColumns = Array.empty[String],
|
||||
provider = classOf[DefaultSource].getCanonicalName,
|
||||
options = Map("path" -> path))
|
||||
|
||||
val d2 = ResolvedDataSource(
|
||||
context,
|
||||
ctx,
|
||||
userSpecifiedSchema = None,
|
||||
partitionColumns = Array.empty[String],
|
||||
provider = classOf[DefaultSource].getCanonicalName,
|
||||
|
@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
|
|||
"abd")
|
||||
|
||||
ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
|
||||
checkAnswer(
|
||||
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
|
||||
checkAnswer(
|
||||
sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
|
||||
checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
|
||||
checkAnswer(sql(
|
||||
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
|
||||
checkAnswer(sql(
|
||||
"SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
|
||||
checkAnswer(sql(
|
||||
"SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.SQLContext
|
||||
|
||||
trait TestJsonData {
|
||||
|
||||
protected def ctx: SQLContext
|
||||
private[json] trait TestJsonData {
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
def primitiveFieldAndType: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"string":"this is a simple string.",
|
||||
"integer":10,
|
||||
"long":21474836470,
|
||||
|
@ -36,7 +35,7 @@ trait TestJsonData {
|
|||
}""" :: Nil)
|
||||
|
||||
def primitiveFieldValueTypeConflict: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
|
||||
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
|
||||
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
|
||||
|
@ -47,14 +46,14 @@ trait TestJsonData {
|
|||
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
|
||||
|
||||
def jsonNullStruct: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
|
||||
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
|
||||
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
|
||||
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
|
||||
|
||||
def complexFieldValueTypeConflict: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"num_struct":11, "str_array":[1, 2, 3],
|
||||
"array":[], "struct_array":[], "struct": {}}""" ::
|
||||
"""{"num_struct":{"field":false}, "str_array":null,
|
||||
|
@ -65,14 +64,14 @@ trait TestJsonData {
|
|||
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
|
||||
|
||||
def arrayElementTypeConflict: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
|
||||
"array2": [{"field":214748364700}, {"field":1}]}""" ::
|
||||
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
|
||||
"""{"array3": [1, 2, 3]}""" :: Nil)
|
||||
|
||||
def missingFields: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"a":true}""" ::
|
||||
"""{"b":21474836470}""" ::
|
||||
"""{"c":[33, 44]}""" ::
|
||||
|
@ -80,7 +79,7 @@ trait TestJsonData {
|
|||
"""{"e":"str"}""" :: Nil)
|
||||
|
||||
def complexFieldAndType1: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"struct":{"field1": true, "field2": 92233720368547758070},
|
||||
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
|
||||
"arrayOfString":["str1", "str2"],
|
||||
|
@ -96,7 +95,7 @@ trait TestJsonData {
|
|||
}""" :: Nil)
|
||||
|
||||
def complexFieldAndType2: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
|
||||
"complexArrayOfStruct": [
|
||||
{
|
||||
|
@ -150,7 +149,7 @@ trait TestJsonData {
|
|||
}""" :: Nil)
|
||||
|
||||
def mapType1: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"map": {"a": 1}}""" ::
|
||||
"""{"map": {"b": 2}}""" ::
|
||||
"""{"map": {"c": 3}}""" ::
|
||||
|
@ -158,7 +157,7 @@ trait TestJsonData {
|
|||
"""{"map": {"e": null}}""" :: Nil)
|
||||
|
||||
def mapType2: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
|
||||
"""{"map": {"b": {"field2": 2}}}""" ::
|
||||
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
|
||||
|
@ -167,21 +166,21 @@ trait TestJsonData {
|
|||
"""{"map": {"f": {"field1": null}}}""" :: Nil)
|
||||
|
||||
def nullsInArrays: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{"field1":[[null], [[["Test"]]]]}""" ::
|
||||
"""{"field2":[null, [{"Test":1}]]}""" ::
|
||||
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
|
||||
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
|
||||
|
||||
def jsonArray: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""[{"a":"str_a_1"}]""" ::
|
||||
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
|
||||
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
|
||||
"""[]""" :: Nil)
|
||||
|
||||
def corruptRecords: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{""" ::
|
||||
"""""" ::
|
||||
"""{"a":1, b:2}""" ::
|
||||
|
@ -190,7 +189,7 @@ trait TestJsonData {
|
|||
"""]""" :: Nil)
|
||||
|
||||
def emptyRecords: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"""{""" ::
|
||||
"""""" ::
|
||||
"""{"a": {}}""" ::
|
||||
|
@ -198,9 +197,8 @@ trait TestJsonData {
|
|||
"""{"b": [{"c": {}}]}""" ::
|
||||
"""]""" :: Nil)
|
||||
|
||||
lazy val singleRow: RDD[String] =
|
||||
ctx.sparkContext.parallelize(
|
||||
"""{"a":123}""" :: Nil)
|
||||
|
||||
def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
|
||||
lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
|
||||
|
||||
def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
|
||||
}
|
||||
|
|
|
@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
|
|||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.parquet.avro.AvroParquetWriter
|
||||
|
||||
import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.execution.datasources.parquet.test.avro._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
|
||||
class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
|
||||
import ParquetCompatibilityTest._
|
||||
|
||||
override val sqlContext: SQLContext = TestSQLContext
|
||||
|
||||
private def withWriter[T <: IndexedRecord]
|
||||
(path: String, schema: Schema)
|
||||
(f: AvroParquetWriter[T] => Unit) = {
|
||||
(f: AvroParquetWriter[T] => Unit): Unit = {
|
||||
val writer = new AvroParquetWriter[T](new Path(path), schema)
|
||||
try f(writer) finally writer.close()
|
||||
}
|
||||
|
@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
|
|||
}
|
||||
|
||||
test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
|
||||
import sqlContext.implicits._
|
||||
import testImplicits._
|
||||
|
||||
withTempPath { dir =>
|
||||
val path = dir.getCanonicalPath
|
||||
|
|
|
@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
|
|||
import org.apache.hadoop.fs.{Path, PathFilter}
|
||||
import org.apache.parquet.hadoop.ParquetFileReader
|
||||
import org.apache.parquet.schema.MessageType
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.sql.QueryTest
|
||||
|
||||
abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
|
||||
def readParquetSchema(path: String): MessageType = {
|
||||
/**
|
||||
* Helper class for testing Parquet compatibility.
|
||||
*/
|
||||
private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
|
||||
protected def readParquetSchema(path: String): MessageType = {
|
||||
readParquetSchema(path, { path => !path.getName.startsWith("_") })
|
||||
}
|
||||
|
||||
def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
|
||||
protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
|
||||
val fsPath = new Path(path)
|
||||
val fs = fsPath.getFileSystem(configuration)
|
||||
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
|
||||
|
|
|
@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
|
|||
import org.apache.parquet.filter2.predicate.Operators._
|
||||
import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
|
||||
|
||||
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
|
||||
import org.apache.spark.sql.execution.datasources.LogicalRelation
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
|
||||
|
||||
/**
|
||||
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
|
||||
|
@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
|
|||
* 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
|
||||
* data type is nullable.
|
||||
*/
|
||||
class ParquetFilterSuite extends QueryTest with ParquetTest {
|
||||
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
|
||||
class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
|
||||
|
||||
private def checkFilterPredicate(
|
||||
df: DataFrame,
|
||||
|
@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
|
|||
}
|
||||
|
||||
test("SPARK-6554: don't push down predicates which reference partition columns") {
|
||||
import sqlContext.implicits._
|
||||
import testImplicits._
|
||||
|
||||
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
|
||||
withTempPath { dir =>
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.apache.spark.SparkException
|
|||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
|
||||
|
@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
|
|||
/**
|
||||
* A test suite that tests basic Parquet I/O.
|
||||
*/
|
||||
class ParquetIOSuite extends QueryTest with ParquetTest {
|
||||
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.implicits._
|
||||
class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
/**
|
||||
* Writes `data` to a Parquet file, reads it back and check file contents.
|
||||
|
|
|
@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import com.google.common.io.Files
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.Literal
|
||||
import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import PartitioningUtils._
|
||||
|
||||
// The data where the partitioning key exists only in the directory structure.
|
||||
case class ParquetData(intField: Int, stringField: String)
|
||||
|
@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
|
|||
// The data that also includes the partitioning key
|
||||
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
|
||||
|
||||
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
|
||||
|
||||
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.implicits._
|
||||
import sqlContext.sql
|
||||
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
|
||||
import PartitioningUtils._
|
||||
import testImplicits._
|
||||
|
||||
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
|
||||
|
||||
|
|
|
@ -17,11 +17,10 @@
|
|||
|
||||
package org.apache.spark.sql.execution.datasources.parquet
|
||||
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
|
||||
import org.apache.spark.sql.{DataFrame, Row}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
|
||||
override def sqlContext: SQLContext = TestSQLContext
|
||||
class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
|
||||
|
||||
private def readParquetProtobufFile(name: String): DataFrame = {
|
||||
val url = Thread.currentThread().getContextClassLoader.getResource(name)
|
||||
|
|
|
@ -21,16 +21,15 @@ import java.io.File
|
|||
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
* A test suite that tests various Parquet queries.
|
||||
*/
|
||||
class ParquetQuerySuite extends QueryTest with ParquetTest {
|
||||
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
|
||||
import sqlContext.sql
|
||||
class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
|
||||
|
||||
test("simple select queries") {
|
||||
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
|
||||
|
@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
|
|||
|
||||
test("appending") {
|
||||
val data = (0 until 10).map(i => (i, i.toString))
|
||||
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
|
||||
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
|
||||
withParquetTable(data, "t") {
|
||||
sql("INSERT INTO TABLE t SELECT * FROM tmp")
|
||||
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
|
||||
checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
|
||||
}
|
||||
sqlContext.catalog.unregisterTable(Seq("tmp"))
|
||||
ctx.catalog.unregisterTable(Seq("tmp"))
|
||||
}
|
||||
|
||||
test("overwriting") {
|
||||
val data = (0 until 10).map(i => (i, i.toString))
|
||||
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
|
||||
ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
|
||||
withParquetTable(data, "t") {
|
||||
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
|
||||
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
|
||||
checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
|
||||
}
|
||||
sqlContext.catalog.unregisterTable(Seq("tmp"))
|
||||
ctx.catalog.unregisterTable(Seq("tmp"))
|
||||
}
|
||||
|
||||
test("self-join") {
|
||||
|
@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
|
|||
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
|
||||
StructField("time", TimestampType, false)).toArray)
|
||||
withTempPath { file =>
|
||||
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
|
||||
val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
|
||||
df.write.parquet(file.getCanonicalPath)
|
||||
val df2 = sqlContext.read.parquet(file.getCanonicalPath)
|
||||
val df2 = ctx.read.parquet(file.getCanonicalPath)
|
||||
checkAnswer(df2, df.collect().toSeq)
|
||||
}
|
||||
}
|
||||
|
@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
|
|||
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
|
||||
withTempDir { dir =>
|
||||
val basePath = dir.getCanonicalPath
|
||||
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
|
||||
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
|
||||
// delete summary files, so if we don't merge part-files, one column will not be included.
|
||||
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
|
||||
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
|
||||
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
|
||||
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
|
|||
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
|
||||
withTempDir { dir =>
|
||||
val basePath = dir.getCanonicalPath
|
||||
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
|
||||
assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
|
||||
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
|
||||
assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
|
|||
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
|
||||
withTempPath { dir =>
|
||||
val basePath = dir.getCanonicalPath
|
||||
sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
|
||||
ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
|
||||
ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
|
||||
|
||||
// Disables the global SQL option for schema merging
|
||||
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
|
||||
assertResult(2) {
|
||||
// Disables schema merging via data source option
|
||||
sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
|
||||
ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
|
||||
}
|
||||
|
||||
assertResult(3) {
|
||||
// Enables schema merging via data source option
|
||||
sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
|
||||
ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
|
|||
|
||||
import org.apache.parquet.schema.MessageTypeParser
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.ScalaReflection
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
|
||||
val sqlContext = TestSQLContext
|
||||
abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
|
||||
|
||||
/**
|
||||
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
|
||||
|
|
|
@ -22,9 +22,8 @@ import java.io.File
|
|||
import scala.reflect.ClassTag
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.{DataFrame, SaveMode}
|
||||
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
|
||||
|
||||
/**
|
||||
* A helper trait that provides convenient facilities for Parquet testing.
|
||||
|
@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
|
|||
* convenient to use tuples rather than special case classes when writing test cases/suites.
|
||||
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
|
||||
*/
|
||||
private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
|
||||
private[sql] trait ParquetTest extends SQLTestUtils {
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
/**
|
||||
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
|
||||
* returns.
|
||||
|
@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
|
|||
(data: Seq[T])
|
||||
(f: String => Unit): Unit = {
|
||||
withTempPath { file =>
|
||||
sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
|
||||
_sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
|
||||
f(file.getCanonicalPath)
|
||||
}
|
||||
}
|
||||
|
@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
|
|||
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
|
||||
(data: Seq[T])
|
||||
(f: DataFrame => Unit): Unit = {
|
||||
withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
|
||||
withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
|
|||
(data: Seq[T], tableName: String)
|
||||
(f: => Unit): Unit = {
|
||||
withParquetDataFrame(data) { df =>
|
||||
sqlContext.registerDataFrameAsTable(df, tableName)
|
||||
_sqlContext.registerDataFrameAsTable(df, tableName)
|
||||
withTempTable(tableName)(f)
|
||||
}
|
||||
}
|
||||
|
||||
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
|
||||
data: Seq[T], path: File): Unit = {
|
||||
sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
|
||||
_sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
|
||||
}
|
||||
|
||||
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql.execution.datasources.parquet
|
||||
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.{Row, SQLContext}
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
|
||||
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
|
||||
import ParquetCompatibilityTest._
|
||||
|
||||
override val sqlContext: SQLContext = TestSQLContext
|
||||
|
||||
private val parquetFilePath =
|
||||
Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")
|
||||
|
||||
|
|
|
@ -18,10 +18,10 @@
|
|||
package org.apache.spark.sql.execution.debug
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.TestData._
|
||||
import org.apache.spark.sql.test.TestSQLContext._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
class DebuggingSuite extends SparkFunSuite {
|
||||
test("DataFrame.debug()") {
|
||||
testData.debug()
|
||||
}
|
||||
|
|
|
@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.execution.metric.SQLMetrics
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
|
||||
import org.apache.spark.util.collection.CompactBuffer
|
||||
|
||||
|
||||
class HashedRelationSuite extends SparkFunSuite {
|
||||
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
// Key is simply the record itself
|
||||
private val keyProjection = new Projection {
|
||||
|
@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
|
|||
|
||||
test("GeneralHashedRelation") {
|
||||
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
|
||||
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
|
||||
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
|
||||
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
|
||||
assert(hashed.isInstanceOf[GeneralHashedRelation])
|
||||
|
||||
|
@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
|
|||
|
||||
test("UniqueKeyHashedRelation") {
|
||||
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
|
||||
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
|
||||
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
|
||||
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
|
||||
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
|
||||
|
||||
|
@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
|
|||
test("UnsafeHashedRelation") {
|
||||
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
|
||||
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
|
||||
val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
|
||||
val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
|
||||
val toUnsafe = UnsafeProjection.create(schema)
|
||||
val unsafeData = data.map(toUnsafe(_).copy()).toArray
|
||||
|
||||
|
|
|
@ -17,97 +17,19 @@
|
|||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
|
||||
import org.apache.spark.sql.catalyst.plans.Inner
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Join
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
|
||||
import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
|
||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
||||
import org.apache.spark.sql.execution._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
|
||||
|
||||
class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
|
||||
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
private def testInnerJoin(
|
||||
testName: String,
|
||||
leftRows: DataFrame,
|
||||
rightRows: DataFrame,
|
||||
condition: Expression,
|
||||
expectedAnswer: Seq[Product]): Unit = {
|
||||
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
|
||||
ExtractEquiJoinKeys.unapply(join).foreach {
|
||||
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
|
||||
|
||||
def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
|
||||
val broadcastHashJoin =
|
||||
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
|
||||
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
|
||||
}
|
||||
|
||||
def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
|
||||
val shuffledHashJoin =
|
||||
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
|
||||
val filteredJoin =
|
||||
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
|
||||
EnsureRequirements(sqlContext).apply(filteredJoin)
|
||||
}
|
||||
|
||||
def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
|
||||
val sortMergeJoin =
|
||||
execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
|
||||
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
|
||||
EnsureRequirements(sqlContext).apply(filteredJoin)
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastHashJoin (build=left)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
makeBroadcastHashJoin(left, right, joins.BuildLeft),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastHashJoin (build=right)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
makeBroadcastHashJoin(left, right, joins.BuildRight),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using ShuffledHashJoin (build=left)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
makeShuffledHashJoin(left, right, joins.BuildLeft),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using ShuffledHashJoin (build=right)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
makeShuffledHashJoin(left, right, joins.BuildRight),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using SortMergeJoin") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
makeSortMergeJoin(left, right),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
private lazy val myUpperCaseData = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(1, "A"),
|
||||
Row(2, "B"),
|
||||
Row(3, "C"),
|
||||
|
@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
Row(null, "G")
|
||||
)), new StructType().add("N", IntegerType).add("L", StringType))
|
||||
|
||||
val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
private lazy val myLowerCaseData = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(1, "a"),
|
||||
Row(2, "b"),
|
||||
Row(3, "c"),
|
||||
|
@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
Row(null, "e")
|
||||
)), new StructType().add("n", IntegerType).add("l", StringType))
|
||||
|
||||
testInnerJoin(
|
||||
"inner join, one match per row",
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
(upperCaseData.col("N") === lowerCaseData.col("n")).expr,
|
||||
Seq(
|
||||
(1, "A", 1, "a"),
|
||||
(2, "B", 2, "b"),
|
||||
(3, "C", 3, "c"),
|
||||
(4, "D", 4, "d")
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
private val testData2 = Seq(
|
||||
private lazy val myTestData = Seq(
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(2, 1),
|
||||
|
@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
(3, 2)
|
||||
).toDF("a", "b")
|
||||
|
||||
// Note: the input dataframes and expression must be evaluated lazily because
|
||||
// the SQLContext should be used only within a test to keep SQL tests stable
|
||||
private def testInnerJoin(
|
||||
testName: String,
|
||||
leftRows: => DataFrame,
|
||||
rightRows: => DataFrame,
|
||||
condition: () => Expression,
|
||||
expectedAnswer: Seq[Product]): Unit = {
|
||||
|
||||
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
|
||||
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
|
||||
ExtractEquiJoinKeys.unapply(join)
|
||||
}
|
||||
|
||||
def makeBroadcastHashJoin(
|
||||
leftKeys: Seq[Expression],
|
||||
rightKeys: Seq[Expression],
|
||||
boundCondition: Option[Expression],
|
||||
leftPlan: SparkPlan,
|
||||
rightPlan: SparkPlan,
|
||||
side: BuildSide) = {
|
||||
val broadcastHashJoin =
|
||||
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
|
||||
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
|
||||
}
|
||||
|
||||
def makeShuffledHashJoin(
|
||||
leftKeys: Seq[Expression],
|
||||
rightKeys: Seq[Expression],
|
||||
boundCondition: Option[Expression],
|
||||
leftPlan: SparkPlan,
|
||||
rightPlan: SparkPlan,
|
||||
side: BuildSide) = {
|
||||
val shuffledHashJoin =
|
||||
execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
|
||||
val filteredJoin =
|
||||
boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
|
||||
EnsureRequirements(sqlContext).apply(filteredJoin)
|
||||
}
|
||||
|
||||
def makeSortMergeJoin(
|
||||
leftKeys: Seq[Expression],
|
||||
rightKeys: Seq[Expression],
|
||||
boundCondition: Option[Expression],
|
||||
leftPlan: SparkPlan,
|
||||
rightPlan: SparkPlan) = {
|
||||
val sortMergeJoin =
|
||||
execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
|
||||
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
|
||||
EnsureRequirements(sqlContext).apply(filteredJoin)
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastHashJoin (build=left)") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
|
||||
makeBroadcastHashJoin(
|
||||
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastHashJoin (build=right)") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
|
||||
makeBroadcastHashJoin(
|
||||
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using ShuffledHashJoin (build=left)") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
|
||||
makeShuffledHashJoin(
|
||||
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using ShuffledHashJoin (build=right)") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
|
||||
makeShuffledHashJoin(
|
||||
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using SortMergeJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
|
||||
makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
testInnerJoin(
|
||||
"inner join, one match per row",
|
||||
myUpperCaseData,
|
||||
myLowerCaseData,
|
||||
() => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
|
||||
Seq(
|
||||
(1, "A", 1, "a"),
|
||||
(2, "B", 2, "b"),
|
||||
(3, "C", 3, "c"),
|
||||
(4, "D", 4, "d")
|
||||
)
|
||||
)
|
||||
|
||||
{
|
||||
val left = testData2.where("a = 1")
|
||||
val right = testData2.where("a = 1")
|
||||
lazy val left = myTestData.where("a = 1")
|
||||
lazy val right = myTestData.where("a = 1")
|
||||
testInnerJoin(
|
||||
"inner join, multiple matches",
|
||||
left,
|
||||
right,
|
||||
(left.col("a") === right.col("a")).expr,
|
||||
() => (left.col("a") === right.col("a")).expr,
|
||||
Seq(
|
||||
(1, 1, 1, 1),
|
||||
(1, 1, 1, 2),
|
||||
|
@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
}
|
||||
|
||||
{
|
||||
val left = testData2.where("a = 1")
|
||||
val right = testData2.where("a = 2")
|
||||
lazy val left = myTestData.where("a = 1")
|
||||
lazy val right = myTestData.where("a = 2")
|
||||
testInnerJoin(
|
||||
"inner join, no matches",
|
||||
left,
|
||||
right,
|
||||
(left.col("a") === right.col("a")).expr,
|
||||
() => (left.col("a") === right.col("a")).expr,
|
||||
Seq.empty
|
||||
)
|
||||
}
|
||||
|
|
|
@ -17,79 +17,19 @@
|
|||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Join
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
|
||||
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans._
|
||||
import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Join
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
|
||||
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
|
||||
|
||||
class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
|
||||
class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
private def testOuterJoin(
|
||||
testName: String,
|
||||
leftRows: DataFrame,
|
||||
rightRows: DataFrame,
|
||||
joinType: JoinType,
|
||||
condition: Expression,
|
||||
expectedAnswer: Seq[Product]): Unit = {
|
||||
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
|
||||
ExtractEquiJoinKeys.unapply(join).foreach {
|
||||
case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
|
||||
test(s"$testName using ShuffledHashOuterJoin") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
if (joinType != FullOuter) {
|
||||
test(s"$testName using BroadcastHashOuterJoin") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using SortMergeOuterJoin") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
private lazy val left = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(1, 2.0),
|
||||
Row(2, 100.0),
|
||||
Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
|
||||
|
@ -100,7 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
Row(null, null)
|
||||
)), new StructType().add("a", IntegerType).add("b", DoubleType))
|
||||
|
||||
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
private lazy val right = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(0, 0.0),
|
||||
Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
|
||||
Row(2, -1.0),
|
||||
|
@ -113,12 +54,64 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
Row(null, null)
|
||||
)), new StructType().add("c", IntegerType).add("d", DoubleType))
|
||||
|
||||
val condition = {
|
||||
And(
|
||||
(left.col("a") === right.col("c")).expr,
|
||||
private lazy val condition = {
|
||||
And((left.col("a") === right.col("c")).expr,
|
||||
LessThan(left.col("b").expr, right.col("d").expr))
|
||||
}
|
||||
|
||||
// Note: the input dataframes and expression must be evaluated lazily because
|
||||
// the SQLContext should be used only within a test to keep SQL tests stable
|
||||
private def testOuterJoin(
|
||||
testName: String,
|
||||
leftRows: => DataFrame,
|
||||
rightRows: => DataFrame,
|
||||
joinType: JoinType,
|
||||
condition: => Expression,
|
||||
expectedAnswer: Seq[Product]): Unit = {
|
||||
|
||||
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
|
||||
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
|
||||
ExtractEquiJoinKeys.unapply(join)
|
||||
}
|
||||
|
||||
test(s"$testName using ShuffledHashOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
ShuffledHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (joinType != FullOuter) {
|
||||
test(s"$testName using BroadcastHashOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using SortMergeOuterJoin") {
|
||||
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(sqlContext).apply(
|
||||
SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
|
||||
expectedAnswer.map(Row.fromTuple),
|
||||
sortAnswers = false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Basic outer joins ------------------------------------------------------------------------
|
||||
|
||||
testOuterJoin(
|
||||
|
|
|
@ -17,27 +17,61 @@
|
|||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
|
||||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
|
||||
import org.apache.spark.sql.catalyst.plans.Inner
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Join
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
|
||||
import org.apache.spark.sql.{SQLConf, DataFrame, Row}
|
||||
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
|
||||
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
|
||||
|
||||
class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
|
||||
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
|
||||
|
||||
private lazy val left = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(1, 2.0),
|
||||
Row(1, 2.0),
|
||||
Row(2, 1.0),
|
||||
Row(2, 1.0),
|
||||
Row(3, 3.0),
|
||||
Row(null, null),
|
||||
Row(null, 5.0),
|
||||
Row(6, null)
|
||||
)), new StructType().add("a", IntegerType).add("b", DoubleType))
|
||||
|
||||
private lazy val right = ctx.createDataFrame(
|
||||
ctx.sparkContext.parallelize(Seq(
|
||||
Row(2, 3.0),
|
||||
Row(2, 3.0),
|
||||
Row(3, 2.0),
|
||||
Row(4, 1.0),
|
||||
Row(null, null),
|
||||
Row(null, 5.0),
|
||||
Row(6, null)
|
||||
)), new StructType().add("c", IntegerType).add("d", DoubleType))
|
||||
|
||||
private lazy val condition = {
|
||||
And((left.col("a") === right.col("c")).expr,
|
||||
LessThan(left.col("b").expr, right.col("d").expr))
|
||||
}
|
||||
|
||||
// Note: the input dataframes and expression must be evaluated lazily because
|
||||
// the SQLContext should be used only within a test to keep SQL tests stable
|
||||
private def testLeftSemiJoin(
|
||||
testName: String,
|
||||
leftRows: DataFrame,
|
||||
rightRows: DataFrame,
|
||||
condition: Expression,
|
||||
leftRows: => DataFrame,
|
||||
rightRows: => DataFrame,
|
||||
condition: => Expression,
|
||||
expectedAnswer: Seq[Product]): Unit = {
|
||||
|
||||
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
|
||||
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
|
||||
ExtractEquiJoinKeys.unapply(join).foreach {
|
||||
case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
|
||||
ExtractEquiJoinKeys.unapply(join)
|
||||
}
|
||||
|
||||
test(s"$testName using LeftSemiJoinHash") {
|
||||
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
EnsureRequirements(left.sqlContext).apply(
|
||||
|
@ -46,8 +80,10 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
sortAnswers = true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$testName using BroadcastLeftSemiJoinHash") {
|
||||
extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
|
||||
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
|
||||
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
|
||||
BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
|
||||
|
@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
|
|||
}
|
||||
}
|
||||
|
||||
val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
Row(1, 2.0),
|
||||
Row(1, 2.0),
|
||||
Row(2, 1.0),
|
||||
Row(2, 1.0),
|
||||
Row(3, 3.0),
|
||||
Row(null, null),
|
||||
Row(null, 5.0),
|
||||
Row(6, null)
|
||||
)), new StructType().add("a", IntegerType).add("b", DoubleType))
|
||||
|
||||
val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
|
||||
Row(2, 3.0),
|
||||
Row(2, 3.0),
|
||||
Row(3, 2.0),
|
||||
Row(4, 1.0),
|
||||
Row(null, null),
|
||||
Row(null, 5.0),
|
||||
Row(6, null)
|
||||
)), new StructType().add("c", IntegerType).add("d", DoubleType))
|
||||
|
||||
val condition = {
|
||||
And(
|
||||
(left.col("a") === right.col("c")).expr,
|
||||
LessThan(left.col("b").expr, right.col("d").expr))
|
||||
}
|
||||
|
||||
testLeftSemiJoin(
|
||||
"basic test",
|
||||
left,
|
||||
|
|
|
@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite
|
|||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.execution.ui.SparkPlanGraph
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
||||
|
||||
override val sqlContext = TestSQLContext
|
||||
|
||||
import sqlContext.implicits._
|
||||
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("LongSQLMetric should not box Long") {
|
||||
val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
|
||||
val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
|
||||
val f = () => {
|
||||
l += 1L
|
||||
l.add(1L)
|
||||
|
@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
test("Normal accumulator should do boxing") {
|
||||
// We need this test to make sure BoxingFinder works.
|
||||
val l = TestSQLContext.sparkContext.accumulator(0L)
|
||||
val l = ctx.sparkContext.accumulator(0L)
|
||||
val f = () => { l += 1L }
|
||||
BoxingFinder.getClassReader(f.getClass).foreach { cl =>
|
||||
val boxingFinder = new BoxingFinder()
|
||||
|
@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
df: DataFrame,
|
||||
expectedNumOfJobs: Int,
|
||||
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
|
||||
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
|
||||
val previousExecutionIds = ctx.listener.executionIdToData.keySet
|
||||
df.collect()
|
||||
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
|
||||
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
|
||||
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
|
||||
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
|
||||
assert(executionIds.size === 1)
|
||||
val executionId = executionIds.head
|
||||
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
|
||||
val jobs = ctx.listener.getExecution(executionId).get.jobs
|
||||
// Use "<=" because there is a race condition that we may miss some jobs
|
||||
// TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
|
||||
assert(jobs.size <= expectedNumOfJobs)
|
||||
if (jobs.size == expectedNumOfJobs) {
|
||||
// If we can track all jobs, check the metric values
|
||||
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
|
||||
val metricValues = ctx.listener.getExecutionMetrics(executionId)
|
||||
val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
|
||||
expectedMetrics.contains(node.id)
|
||||
}.map { node =>
|
||||
|
@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
|
||||
// Assume the execution plan is
|
||||
// PhysicalRDD(nodeId = 1) -> Project(nodeId = 0)
|
||||
val df = TestData.person.select('name)
|
||||
val df = person.select('name)
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
0L ->("Project", Map(
|
||||
"number of rows" -> 2L)))
|
||||
|
@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
SQLConf.TUNGSTEN_ENABLED.key -> "true") {
|
||||
// Assume the execution plan is
|
||||
// PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0)
|
||||
val df = TestData.person.select('name)
|
||||
val df = person.select('name)
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
0L ->("TungstenProject", Map(
|
||||
"number of rows" -> 2L)))
|
||||
|
@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
test("Filter metrics") {
|
||||
// Assume the execution plan is
|
||||
// PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
|
||||
val df = TestData.person.filter('age < 25)
|
||||
val df = person.filter('age < 25)
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
0L -> ("Filter", Map(
|
||||
"number of input rows" -> 2L,
|
||||
|
@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
|
||||
// Assume the execution plan is
|
||||
// ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
|
||||
val df = TestData.testData2.groupBy().count() // 2 partitions
|
||||
val df = testData2.groupBy().count() // 2 partitions
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
2L -> ("Aggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
)
|
||||
|
||||
// 2 partitions and each partition contains 2 keys
|
||||
val df2 = TestData.testData2.groupBy('a).count()
|
||||
val df2 = testData2.groupBy('a).count()
|
||||
testSparkPlanMetrics(df2, 1, Map(
|
||||
2L -> ("Aggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
// Assume the execution plan is
|
||||
// ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) ->
|
||||
// SortBasedAggregate(nodeId = 0)
|
||||
val df = TestData.testData2.groupBy().count() // 2 partitions
|
||||
val df = testData2.groupBy().count() // 2 partitions
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
2L -> ("SortBasedAggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
// ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2)
|
||||
// -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0)
|
||||
// 2 partitions and each partition contains 2 keys
|
||||
val df2 = TestData.testData2.groupBy('a).count()
|
||||
val df2 = testData2.groupBy('a).count()
|
||||
testSparkPlanMetrics(df2, 1, Map(
|
||||
3L -> ("SortBasedAggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
// Assume the execution plan is
|
||||
// ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1)
|
||||
// -> TungstenAggregate(nodeId = 0)
|
||||
val df = TestData.testData2.groupBy().count() // 2 partitions
|
||||
val df = testData2.groupBy().count() // 2 partitions
|
||||
testSparkPlanMetrics(df, 1, Map(
|
||||
2L -> ("TungstenAggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
)
|
||||
|
||||
// 2 partitions and each partition contains 2 keys
|
||||
val df2 = TestData.testData2.groupBy('a).count()
|
||||
val df2 = testData2.groupBy('a).count()
|
||||
testSparkPlanMetrics(df2, 1, Map(
|
||||
2L -> ("TungstenAggregate", Map(
|
||||
"number of input rows" -> 6L,
|
||||
|
@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
|
||||
// test should use the deterministic number of partitions.
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
|
||||
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.registerTempTable("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
|
@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
// Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
|
||||
// this test should use the deterministic number of partitions.
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
|
||||
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.registerTempTable("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
|
@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
test("ShuffledHashJoin metrics") {
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
|
||||
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.registerTempTable("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
|
@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
test("BroadcastNestedLoopJoin metrics") {
|
||||
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
|
||||
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.registerTempTable("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
|
@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
}
|
||||
|
||||
test("CartesianProduct metrics") {
|
||||
val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
|
||||
testDataForJoin.registerTempTable("testDataForJoin")
|
||||
withTempTable("testDataForJoin") {
|
||||
// Assume the execution plan is
|
||||
|
@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
|
|||
|
||||
test("save metrics") {
|
||||
withTempPath { file =>
|
||||
val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
|
||||
val previousExecutionIds = ctx.listener.executionIdToData.keySet
|
||||
// Assume the execution plan is
|
||||
// PhysicalRDD(nodeId = 0)
|
||||
TestData.person.select('name).write.format("json").save(file.getAbsolutePath)
|
||||
TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
|
||||
val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
|
||||
person.select('name).write.format("json").save(file.getAbsolutePath)
|
||||
ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
|
||||
val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
|
||||
assert(executionIds.size === 1)
|
||||
val executionId = executionIds.head
|
||||
val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
|
||||
val jobs = ctx.listener.getExecution(executionId).get.jobs
|
||||
// Use "<=" because there is a race condition that we may miss some jobs
|
||||
// TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
|
||||
assert(jobs.size <= 1)
|
||||
val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
|
||||
val metricValues = ctx.listener.getExecutionMetrics(executionId)
|
||||
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
|
||||
// However, we still can check the value.
|
||||
assert(metricValues.values.toSeq === Seq(2L))
|
||||
|
|
|
@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue
|
|||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.execution.SQLExecution
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class SQLListenerSuite extends SparkFunSuite {
|
||||
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
private def createTestDataFrame: DataFrame = {
|
||||
import TestSQLContext.implicits._
|
||||
Seq(
|
||||
(1, 1),
|
||||
(2, 2)
|
||||
|
@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("basic") {
|
||||
val listener = new SQLListener(TestSQLContext)
|
||||
val listener = new SQLListener(ctx)
|
||||
val executionId = 0
|
||||
val df = createTestDataFrame
|
||||
val accumulatorIds =
|
||||
|
@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
|
||||
val listener = new SQLListener(TestSQLContext)
|
||||
val listener = new SQLListener(ctx)
|
||||
val executionId = 0
|
||||
val df = createTestDataFrame
|
||||
listener.onExecutionStart(
|
||||
|
@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
|
||||
val listener = new SQLListener(TestSQLContext)
|
||||
val listener = new SQLListener(ctx)
|
||||
val executionId = 0
|
||||
val df = createTestDataFrame
|
||||
listener.onExecutionStart(
|
||||
|
@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
|
||||
val listener = new SQLListener(TestSQLContext)
|
||||
val listener = new SQLListener(ctx)
|
||||
val executionId = 0
|
||||
val df = createTestDataFrame
|
||||
listener.onExecutionStart(
|
||||
|
|
|
@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException
|
|||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|
||||
class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
val url = "jdbc:h2:mem:testdb0"
|
||||
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
|
||||
var conn: java.sql.Connection = null
|
||||
|
@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
Some(StringType)
|
||||
}
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import ctx.sql
|
||||
|
||||
before {
|
||||
Utils.classForName("org.h2.Driver")
|
||||
// Extra properties that will be specified for our database. We need these to test
|
||||
|
|
|
@ -23,11 +23,13 @@ import java.util.Properties
|
|||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.{SaveMode, Row}
|
||||
import org.apache.spark.sql.{Row, SaveMode}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
|
||||
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
|
||||
|
||||
val url = "jdbc:h2:mem:testdb2"
|
||||
var conn: java.sql.Connection = null
|
||||
val url1 = "jdbc:h2:mem:testdb3"
|
||||
|
@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
properties.setProperty("password", "testPass")
|
||||
properties.setProperty("rowId", "false")
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
|
||||
import ctx.implicits._
|
||||
import ctx.sql
|
||||
|
||||
before {
|
||||
Utils.classForName("org.h2.Driver")
|
||||
conn = DriverManager.getConnection(url)
|
||||
|
@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
|
||||
conn1.commit()
|
||||
|
||||
ctx.sql(
|
||||
sql(
|
||||
s"""
|
||||
|CREATE TEMPORARY TABLE PEOPLE
|
||||
|USING org.apache.spark.sql.jdbc
|
||||
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
|
||||
""".stripMargin.replaceAll("\n", " "))
|
||||
|
||||
ctx.sql(
|
||||
sql(
|
||||
s"""
|
||||
|CREATE TEMPORARY TABLE PEOPLE1
|
||||
|USING org.apache.spark.sql.jdbc
|
||||
|
@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
|
|||
}
|
||||
|
||||
test("INSERT to JDBC Datasource") {
|
||||
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
|
||||
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
|
||||
}
|
||||
|
||||
test("INSERT to JDBC Datasource with overwrite") {
|
||||
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
|
||||
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
|
||||
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
|
||||
}
|
||||
|
|
|
@ -19,28 +19,32 @@ package org.apache.spark.sql.sources
|
|||
|
||||
import java.io.{File, IOException}
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.execution.datasources.DDLException
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
|
||||
|
||||
import caseInsensitiveContext.sql
|
||||
|
||||
class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
private lazy val sparkContext = caseInsensitiveContext.sparkContext
|
||||
|
||||
var path: File = null
|
||||
private var path: File = null
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
path = Utils.createTempDir()
|
||||
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
|
||||
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
try {
|
||||
caseInsensitiveContext.dropTempTable("jt")
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
after {
|
||||
|
|
|
@ -18,11 +18,12 @@
|
|||
package org.apache.spark.sql.sources
|
||||
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types.{StringType, StructField, StructType}
|
||||
|
||||
|
||||
// please note that the META-INF/services had to be modified for the test directory for this to work
|
||||
class DDLSourceLoadSuite extends DataSourceTest {
|
||||
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
|
||||
|
||||
test("data sources with the same name") {
|
||||
intercept[RuntimeException] {
|
||||
|
|
|
@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
|
|||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
|
|||
}
|
||||
}
|
||||
|
||||
class DDLTestSuite extends DataSourceTest {
|
||||
class DDLTestSuite extends DataSourceTest with SharedSQLContext {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
|
||||
before {
|
||||
caseInsensitiveContext.sql(
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sql(
|
||||
"""
|
||||
|CREATE TEMPORARY TABLE ddlPeople
|
||||
|USING org.apache.spark.sql.sources.DDLScanSource
|
||||
|
@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest {
|
|||
))
|
||||
|
||||
test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
|
||||
val attributes = caseInsensitiveContext.sql("describe ddlPeople")
|
||||
val attributes = sql("describe ddlPeople")
|
||||
.queryExecution.executedPlan.output
|
||||
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
|
||||
assert(attributes.map(_.dataType).toSet === Set(StringType))
|
||||
|
|
|
@ -17,18 +17,23 @@
|
|||
|
||||
package org.apache.spark.sql.sources
|
||||
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
|
||||
|
||||
abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
|
||||
private[sql] abstract class DataSourceTest extends QueryTest {
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
// We want to test some edge cases.
|
||||
protected implicit lazy val caseInsensitiveContext = {
|
||||
val ctx = new SQLContext(TestSQLContext.sparkContext)
|
||||
protected lazy val caseInsensitiveContext: SQLContext = {
|
||||
val ctx = new SQLContext(_sqlContext.sparkContext)
|
||||
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
|
||||
ctx
|
||||
}
|
||||
|
||||
protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) {
|
||||
test(sqlString) {
|
||||
checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import scala.language.existentials
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
|
@ -96,11 +97,11 @@ object FiltersPushed {
|
|||
var list: Seq[Filter] = Nil
|
||||
}
|
||||
|
||||
class FilteredScanSuite extends DataSourceTest {
|
||||
class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
|
||||
import caseInsensitiveContext.sql
|
||||
|
||||
before {
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sql(
|
||||
"""
|
||||
|CREATE TEMPORARY TABLE oneToTenFiltered
|
||||
|
|
|
@ -19,20 +19,17 @@ package org.apache.spark.sql.sources
|
|||
|
||||
import java.io.File
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
|
||||
|
||||
import caseInsensitiveContext.sql
|
||||
|
||||
class InsertSuite extends DataSourceTest with SharedSQLContext {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
private lazy val sparkContext = caseInsensitiveContext.sparkContext
|
||||
|
||||
var path: File = null
|
||||
private var path: File = null
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
path = Utils.createTempDir()
|
||||
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
|
||||
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
|
||||
|
@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
|
|||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
try {
|
||||
caseInsensitiveContext.dropTempTable("jsonTable")
|
||||
caseInsensitiveContext.dropTempTable("jt")
|
||||
Utils.deleteRecursively(path)
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
test("Simple INSERT OVERWRITE a JSONRelation") {
|
||||
|
@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
|
|||
sql("SELECT a * 2 FROM jsonTable"),
|
||||
(1 to 10).map(i => Row(i * 2)).toSeq)
|
||||
|
||||
assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
|
||||
checkAnswer(
|
||||
sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
|
||||
assertCached(sql(
|
||||
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
|
||||
checkAnswer(sql(
|
||||
"SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
|
||||
(2 to 10).map(i => Row(i, i - 1)).toSeq)
|
||||
|
||||
// Insert overwrite and keep the same schema.
|
||||
|
|
|
@ -19,21 +19,21 @@ package org.apache.spark.sql.sources
|
|||
|
||||
import org.apache.spark.sql.{Row, QueryTest}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.TestSQLContext
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class PartitionedWriteSuite extends QueryTest {
|
||||
import TestSQLContext.implicits._
|
||||
class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("write many partitions") {
|
||||
val path = Utils.createTempDir()
|
||||
path.delete()
|
||||
|
||||
val df = TestSQLContext.range(100).select($"id", lit(1).as("data"))
|
||||
val df = ctx.range(100).select($"id", lit(1).as("data"))
|
||||
df.write.partitionBy("id").save(path.getCanonicalPath)
|
||||
|
||||
checkAnswer(
|
||||
TestSQLContext.read.load(path.getCanonicalPath),
|
||||
ctx.read.load(path.getCanonicalPath),
|
||||
(0 to 99).map(Row(1, _)).toSeq)
|
||||
|
||||
Utils.deleteRecursively(path)
|
||||
|
@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest {
|
|||
val path = Utils.createTempDir()
|
||||
path.delete()
|
||||
|
||||
val base = TestSQLContext.range(100)
|
||||
val base = ctx.range(100)
|
||||
val df = base.unionAll(base).select($"id", lit(1).as("data"))
|
||||
df.write.partitionBy("id").save(path.getCanonicalPath)
|
||||
|
||||
checkAnswer(
|
||||
TestSQLContext.read.load(path.getCanonicalPath),
|
||||
ctx.read.load(path.getCanonicalPath),
|
||||
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
|
||||
|
||||
Utils.deleteRecursively(path)
|
||||
|
|
|
@ -21,6 +21,7 @@ import scala.language.existentials
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class PrunedScanSource extends RelationProvider {
|
||||
|
@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
|
|||
}
|
||||
}
|
||||
|
||||
class PrunedScanSuite extends DataSourceTest {
|
||||
class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
|
||||
before {
|
||||
caseInsensitiveContext.sql(
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sql(
|
||||
"""
|
||||
|CREATE TEMPORARY TABLE oneToTenPruned
|
||||
|USING org.apache.spark.sql.sources.PrunedScanSource
|
||||
|
@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest {
|
|||
|
||||
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
|
||||
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
|
||||
val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
|
||||
val queryExecution = sql(sqlString).queryExecution
|
||||
val rawPlan = queryExecution.executedPlan.collect {
|
||||
case p: execution.PhysicalRDD => p
|
||||
} match {
|
||||
|
|
|
@ -19,25 +19,22 @@ package org.apache.spark.sql.sources
|
|||
|
||||
import java.io.File
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame}
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
|
||||
|
||||
import caseInsensitiveContext.sql
|
||||
|
||||
class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
private lazy val sparkContext = caseInsensitiveContext.sparkContext
|
||||
|
||||
var originalDefaultSource: String = null
|
||||
|
||||
var path: File = null
|
||||
|
||||
var df: DataFrame = null
|
||||
private var originalDefaultSource: String = null
|
||||
private var path: File = null
|
||||
private var df: DataFrame = null
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
|
||||
|
||||
path = Utils.createTempDir()
|
||||
|
@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
|
|||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
try {
|
||||
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
after {
|
||||
caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
|
||||
Utils.deleteRecursively(path)
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
|
|||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class DefaultSource extends SimpleScanSource
|
||||
|
@ -95,8 +96,8 @@ case class AllDataTypesScan(
|
|||
}
|
||||
}
|
||||
|
||||
class TableScanSuite extends DataSourceTest {
|
||||
import caseInsensitiveContext.sql
|
||||
class TableScanSuite extends DataSourceTest with SharedSQLContext {
|
||||
protected override lazy val sql = caseInsensitiveContext.sql _
|
||||
|
||||
private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
|
||||
Row(
|
||||
|
@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest {
|
|||
Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))))
|
||||
}.toSeq
|
||||
|
||||
before {
|
||||
override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
sql(
|
||||
"""
|
||||
|CREATE TEMPORARY TABLE oneToTen
|
||||
|
@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest {
|
|||
sql("SELECT i * 2 FROM oneToTen"),
|
||||
(1 to 10).map(i => Row(i * 2)).toSeq)
|
||||
|
||||
assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
|
||||
checkAnswer(
|
||||
sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
|
||||
assertCached(sql(
|
||||
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
|
||||
checkAnswer(sql(
|
||||
"SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
|
||||
(2 to 10).map(i => Row(i, i - 1)).toSeq)
|
||||
|
||||
// Verify uncaching
|
||||
|
|
|
@ -0,0 +1,290 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.test
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
|
||||
|
||||
/**
|
||||
* A collection of sample data used in SQL tests.
|
||||
*/
|
||||
private[sql] trait SQLTestData { self =>
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
// Helper object to import SQL implicits without a concrete SQLContext
|
||||
private object internalImplicits extends SQLImplicits {
|
||||
protected override def _sqlContext: SQLContext = self._sqlContext
|
||||
}
|
||||
|
||||
import internalImplicits._
|
||||
import SQLTestData._
|
||||
|
||||
// Note: all test data should be lazy because the SQLContext is not set up yet.
|
||||
|
||||
protected lazy val testData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
(1 to 100).map(i => TestData(i, i.toString))).toDF()
|
||||
df.registerTempTable("testData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val testData2: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
TestData2(1, 1) ::
|
||||
TestData2(1, 2) ::
|
||||
TestData2(2, 1) ::
|
||||
TestData2(2, 2) ::
|
||||
TestData2(3, 1) ::
|
||||
TestData2(3, 2) :: Nil, 2).toDF()
|
||||
df.registerTempTable("testData2")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val testData3: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
TestData3(1, None) ::
|
||||
TestData3(2, Some(2)) :: Nil).toDF()
|
||||
df.registerTempTable("testData3")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val negativeData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
(1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
|
||||
df.registerTempTable("negativeData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val largeAndSmallInts: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
LargeAndSmallInts(2147483644, 1) ::
|
||||
LargeAndSmallInts(1, 2) ::
|
||||
LargeAndSmallInts(2147483645, 1) ::
|
||||
LargeAndSmallInts(2, 2) ::
|
||||
LargeAndSmallInts(2147483646, 1) ::
|
||||
LargeAndSmallInts(3, 2) :: Nil).toDF()
|
||||
df.registerTempTable("largeAndSmallInts")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val decimalData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
DecimalData(1, 1) ::
|
||||
DecimalData(1, 2) ::
|
||||
DecimalData(2, 1) ::
|
||||
DecimalData(2, 2) ::
|
||||
DecimalData(3, 1) ::
|
||||
DecimalData(3, 2) :: Nil).toDF()
|
||||
df.registerTempTable("decimalData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val binaryData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
BinaryData("12".getBytes, 1) ::
|
||||
BinaryData("22".getBytes, 5) ::
|
||||
BinaryData("122".getBytes, 3) ::
|
||||
BinaryData("121".getBytes, 2) ::
|
||||
BinaryData("123".getBytes, 4) :: Nil).toDF()
|
||||
df.registerTempTable("binaryData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val upperCaseData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
UpperCaseData(1, "A") ::
|
||||
UpperCaseData(2, "B") ::
|
||||
UpperCaseData(3, "C") ::
|
||||
UpperCaseData(4, "D") ::
|
||||
UpperCaseData(5, "E") ::
|
||||
UpperCaseData(6, "F") :: Nil).toDF()
|
||||
df.registerTempTable("upperCaseData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val lowerCaseData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
LowerCaseData(1, "a") ::
|
||||
LowerCaseData(2, "b") ::
|
||||
LowerCaseData(3, "c") ::
|
||||
LowerCaseData(4, "d") :: Nil).toDF()
|
||||
df.registerTempTable("lowerCaseData")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val arrayData: RDD[ArrayData] = {
|
||||
val rdd = _sqlContext.sparkContext.parallelize(
|
||||
ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
|
||||
ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
|
||||
rdd.toDF().registerTempTable("arrayData")
|
||||
rdd
|
||||
}
|
||||
|
||||
protected lazy val mapData: RDD[MapData] = {
|
||||
val rdd = _sqlContext.sparkContext.parallelize(
|
||||
MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
|
||||
MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
|
||||
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
|
||||
MapData(Map(1 -> "a4", 2 -> "b4")) ::
|
||||
MapData(Map(1 -> "a5")) :: Nil)
|
||||
rdd.toDF().registerTempTable("mapData")
|
||||
rdd
|
||||
}
|
||||
|
||||
protected lazy val repeatedData: RDD[StringData] = {
|
||||
val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
|
||||
rdd.toDF().registerTempTable("repeatedData")
|
||||
rdd
|
||||
}
|
||||
|
||||
protected lazy val nullableRepeatedData: RDD[StringData] = {
|
||||
val rdd = _sqlContext.sparkContext.parallelize(
|
||||
List.fill(2)(StringData(null)) ++
|
||||
List.fill(2)(StringData("test")))
|
||||
rdd.toDF().registerTempTable("nullableRepeatedData")
|
||||
rdd
|
||||
}
|
||||
|
||||
protected lazy val nullInts: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
NullInts(1) ::
|
||||
NullInts(2) ::
|
||||
NullInts(3) ::
|
||||
NullInts(null) :: Nil).toDF()
|
||||
df.registerTempTable("nullInts")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val allNulls: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
NullInts(null) ::
|
||||
NullInts(null) ::
|
||||
NullInts(null) ::
|
||||
NullInts(null) :: Nil).toDF()
|
||||
df.registerTempTable("allNulls")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val nullStrings: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
NullStrings(1, "abc") ::
|
||||
NullStrings(2, "ABC") ::
|
||||
NullStrings(3, null) :: Nil).toDF()
|
||||
df.registerTempTable("nullStrings")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val tableName: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
|
||||
df.registerTempTable("tableName")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val unparsedStrings: RDD[String] = {
|
||||
_sqlContext.sparkContext.parallelize(
|
||||
"1, A1, true, null" ::
|
||||
"2, B2, false, null" ::
|
||||
"3, C3, true, null" ::
|
||||
"4, D4, true, 2147483644" :: Nil)
|
||||
}
|
||||
|
||||
// An RDD with 4 elements and 8 partitions
|
||||
protected lazy val withEmptyParts: RDD[IntField] = {
|
||||
val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
|
||||
rdd.toDF().registerTempTable("withEmptyParts")
|
||||
rdd
|
||||
}
|
||||
|
||||
protected lazy val person: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
Person(0, "mike", 30) ::
|
||||
Person(1, "jim", 20) :: Nil).toDF()
|
||||
df.registerTempTable("person")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val salary: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
Salary(0, 2000.0) ::
|
||||
Salary(1, 1000.0) :: Nil).toDF()
|
||||
df.registerTempTable("salary")
|
||||
df
|
||||
}
|
||||
|
||||
protected lazy val complexData: DataFrame = {
|
||||
val df = _sqlContext.sparkContext.parallelize(
|
||||
ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
|
||||
ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
|
||||
Nil).toDF()
|
||||
df.registerTempTable("complexData")
|
||||
df
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize all test data such that all temp tables are properly registered.
|
||||
*/
|
||||
def loadTestData(): Unit = {
|
||||
assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
|
||||
testData
|
||||
testData2
|
||||
testData3
|
||||
negativeData
|
||||
largeAndSmallInts
|
||||
decimalData
|
||||
binaryData
|
||||
upperCaseData
|
||||
lowerCaseData
|
||||
arrayData
|
||||
mapData
|
||||
repeatedData
|
||||
nullableRepeatedData
|
||||
nullInts
|
||||
allNulls
|
||||
nullStrings
|
||||
tableName
|
||||
unparsedStrings
|
||||
withEmptyParts
|
||||
person
|
||||
salary
|
||||
complexData
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Case classes used in test data.
|
||||
*/
|
||||
private[sql] object SQLTestData {
|
||||
case class TestData(key: Int, value: String)
|
||||
case class TestData2(a: Int, b: Int)
|
||||
case class TestData3(a: Int, b: Option[Int])
|
||||
case class LargeAndSmallInts(a: Int, b: Int)
|
||||
case class DecimalData(a: BigDecimal, b: BigDecimal)
|
||||
case class BinaryData(a: Array[Byte], b: Int)
|
||||
case class UpperCaseData(N: Int, L: String)
|
||||
case class LowerCaseData(n: Int, l: String)
|
||||
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
|
||||
case class MapData(data: scala.collection.Map[Int, String])
|
||||
case class StringData(s: String)
|
||||
case class IntField(i: Int)
|
||||
case class NullInts(a: Integer)
|
||||
case class NullStrings(n: Int, s: String)
|
||||
case class TableName(tableName: String)
|
||||
case class Person(id: Int, name: String, age: Int)
|
||||
case class Salary(personId: Int, salary: Double)
|
||||
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
|
||||
}
|
|
@ -21,15 +21,71 @@ import java.io.File
|
|||
import java.util.UUID
|
||||
|
||||
import scala.util.Try
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
trait SQLTestUtils { this: SparkFunSuite =>
|
||||
protected def sqlContext: SQLContext
|
||||
/**
|
||||
* Helper trait that should be extended by all SQL test suites.
|
||||
*
|
||||
* This allows subclasses to plugin a custom [[SQLContext]]. It comes with test data
|
||||
* prepared in advance as well as all implicit conversions used extensively by dataframes.
|
||||
* To use implicit methods, import `testImplicits._` instead of through the [[SQLContext]].
|
||||
*
|
||||
* Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is
|
||||
* prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
|
||||
*/
|
||||
private[sql] trait SQLTestUtils
|
||||
extends SparkFunSuite
|
||||
with BeforeAndAfterAll
|
||||
with SQLTestData { self =>
|
||||
|
||||
protected def configuration = sqlContext.sparkContext.hadoopConfiguration
|
||||
protected def _sqlContext: SQLContext
|
||||
|
||||
// Whether to materialize all test data before the first test is run
|
||||
private var loadTestDataBeforeTests = false
|
||||
|
||||
// Shorthand for running a query using our SQLContext
|
||||
protected lazy val sql = _sqlContext.sql _
|
||||
|
||||
/**
|
||||
* A helper object for importing SQL implicits.
|
||||
*
|
||||
* Note that the alternative of importing `sqlContext.implicits._` is not possible here.
|
||||
* This is because we create the [[SQLContext]] immediately before the first test is run,
|
||||
* but the implicits import is needed in the constructor.
|
||||
*/
|
||||
protected object testImplicits extends SQLImplicits {
|
||||
protected override def _sqlContext: SQLContext = self._sqlContext
|
||||
}
|
||||
|
||||
/**
|
||||
* Materialize the test data immediately after the [[SQLContext]] is set up.
|
||||
* This is necessary if the data is accessed by name but not through direct reference.
|
||||
*/
|
||||
protected def setupTestData(): Unit = {
|
||||
loadTestDataBeforeTests = true
|
||||
}
|
||||
|
||||
protected override def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
if (loadTestDataBeforeTests) {
|
||||
loadTestData()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The Hadoop configuration used by the active [[SQLContext]].
|
||||
*/
|
||||
protected def configuration: Configuration = {
|
||||
_sqlContext.sparkContext.hadoopConfiguration
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
|
||||
|
@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
|
|||
*/
|
||||
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
|
||||
val (keys, values) = pairs.unzip
|
||||
val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
|
||||
(keys, values).zipped.foreach(sqlContext.conf.setConfString)
|
||||
val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
|
||||
(keys, values).zipped.foreach(_sqlContext.conf.setConfString)
|
||||
try f finally {
|
||||
keys.zip(currentValues).foreach {
|
||||
case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
|
||||
case (key, None) => sqlContext.conf.unsetConf(key)
|
||||
case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
|
||||
case (key, None) => _sqlContext.conf.unsetConf(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
|
|||
* Drops temporary table `tableName` after calling `f`.
|
||||
*/
|
||||
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
|
||||
try f finally tableNames.foreach(sqlContext.dropTempTable)
|
||||
try f finally tableNames.foreach(_sqlContext.dropTempTable)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
|
|||
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
|
||||
try f finally {
|
||||
tableNames.foreach { name =>
|
||||
sqlContext.sql(s"DROP TABLE IF EXISTS $name")
|
||||
_sqlContext.sql(s"DROP TABLE IF EXISTS $name")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
|
|||
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
|
||||
|
||||
try {
|
||||
sqlContext.sql(s"CREATE DATABASE $dbName")
|
||||
_sqlContext.sql(s"CREATE DATABASE $dbName")
|
||||
} catch { case cause: Throwable =>
|
||||
fail("Failed to create temporary database", cause)
|
||||
}
|
||||
|
||||
try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
|
||||
try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite =>
|
|||
* `f` returns.
|
||||
*/
|
||||
protected def activateDatabase(db: String)(f: => Unit): Unit = {
|
||||
sqlContext.sql(s"USE $db")
|
||||
try f finally sqlContext.sql(s"USE default")
|
||||
_sqlContext.sql(s"USE $db")
|
||||
try f finally _sqlContext.sql(s"USE default")
|
||||
}
|
||||
|
||||
/**
|
||||
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier
|
||||
* way to construct [[DataFrame]] directly out of local data without relying on implicits.
|
||||
*/
|
||||
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
|
||||
DataFrame(_sqlContext, plan)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.test
|
||||
|
||||
import org.apache.spark.sql.SQLContext
|
||||
|
||||
|
||||
/**
|
||||
* Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
|
||||
*/
|
||||
private[sql] trait SharedSQLContext extends SQLTestUtils {
|
||||
|
||||
/**
|
||||
* The [[TestSQLContext]] to use for all tests in this suite.
|
||||
*
|
||||
* By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
|
||||
* mode with the default test configurations.
|
||||
*/
|
||||
private var _ctx: TestSQLContext = null
|
||||
|
||||
/**
|
||||
* The [[TestSQLContext]] to use for all tests in this suite.
|
||||
*/
|
||||
protected def ctx: TestSQLContext = _ctx
|
||||
protected def sqlContext: TestSQLContext = _ctx
|
||||
protected override def _sqlContext: SQLContext = _ctx
|
||||
|
||||
/**
|
||||
* Initialize the [[TestSQLContext]].
|
||||
*/
|
||||
protected override def beforeAll(): Unit = {
|
||||
if (_ctx == null) {
|
||||
_ctx = new TestSQLContext
|
||||
}
|
||||
// Ensure we have initialized the context before calling parent code
|
||||
super.beforeAll()
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the underlying [[org.apache.spark.SparkContext]], if any.
|
||||
*/
|
||||
protected override def afterAll(): Unit = {
|
||||
try {
|
||||
if (_ctx != null) {
|
||||
_ctx.sparkContext.stop()
|
||||
_ctx = null
|
||||
}
|
||||
} finally {
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -17,40 +17,36 @@
|
|||
|
||||
package org.apache.spark.sql.test
|
||||
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.{SQLConf, SQLContext}
|
||||
|
||||
/** A SQLContext that can be used for local testing. */
|
||||
class LocalSQLContext
|
||||
extends SQLContext(
|
||||
new SparkContext("local[2]", "TestSQLContext", new SparkConf()
|
||||
.set("spark.sql.testkey", "true")
|
||||
// SPARK-8910
|
||||
.set("spark.ui.enabled", "false"))) {
|
||||
|
||||
override protected[sql] def createSession(): SQLSession = {
|
||||
new this.SQLSession()
|
||||
/**
|
||||
* A special [[SQLContext]] prepared for testing.
|
||||
*/
|
||||
private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
|
||||
|
||||
def this() {
|
||||
this(new SparkContext("local[2]", "test-sql-context",
|
||||
new SparkConf().set("spark.sql.testkey", "true")))
|
||||
}
|
||||
|
||||
// Use fewer partitions to speed up testing
|
||||
protected[sql] override def createSession(): SQLSession = new this.SQLSession()
|
||||
|
||||
/** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
|
||||
protected[sql] class SQLSession extends super.SQLSession {
|
||||
protected[sql] override lazy val conf: SQLConf = new SQLConf {
|
||||
/** Fewer partitions to speed up testing. */
|
||||
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to
|
||||
* construct [[DataFrame]] directly out of local data without relying on implicits.
|
||||
*/
|
||||
protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
|
||||
DataFrame(this, plan)
|
||||
// Needed for Java tests
|
||||
def loadTestData(): Unit = {
|
||||
testData.loadTestData()
|
||||
}
|
||||
|
||||
private object testData extends SQLTestData {
|
||||
protected override def _sqlContext: SQLContext = self
|
||||
}
|
||||
}
|
||||
|
||||
object TestSQLContext extends LocalSQLContext
|
||||
|
|
@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._
|
|||
import org.scalatest.selenium.WebBrowser
|
||||
import org.scalatest.time.SpanSugar._
|
||||
|
||||
import org.apache.spark.sql.hive.HiveContext
|
||||
import org.apache.spark.ui.SparkUICssErrorHandler
|
||||
|
||||
class UISeleniumSuite
|
||||
|
@ -36,7 +35,6 @@ class UISeleniumSuite
|
|||
|
||||
implicit var webDriver: WebDriver = _
|
||||
var server: HiveThriftServer2 = _
|
||||
var hc: HiveContext = _
|
||||
val uiPort = 20000 + Random.nextInt(10000)
|
||||
override def mode: ServerMode.Value = ServerMode.binary
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
|
|||
import org.apache.spark.sql.sources.DataSourceTest
|
||||
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
|
||||
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
|
||||
import org.apache.spark.sql.{Row, SaveMode}
|
||||
import org.apache.spark.sql.{Row, SaveMode, SQLContext}
|
||||
import org.apache.spark.{Logging, SparkFunSuite}
|
||||
|
||||
|
||||
|
@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
|
|||
}
|
||||
|
||||
class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
|
||||
override val sqlContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
import testImplicits._
|
||||
|
||||
private val testDF = range(1, 3).select(
|
||||
('id + 0.1) cast DecimalType(10, 3) as 'd1,
|
||||
|
|
|
@ -19,14 +19,13 @@ package org.apache.spark.sql.hive
|
|||
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
|
||||
import org.apache.spark.sql.{QueryTest, Row}
|
||||
import org.apache.spark.sql.{QueryTest, Row, SQLContext}
|
||||
|
||||
case class Cases(lower: String, UPPER: String)
|
||||
|
||||
class HiveParquetSuite extends QueryTest with ParquetTest {
|
||||
val sqlContext = TestHive
|
||||
|
||||
import sqlContext._
|
||||
private val ctx = TestHive
|
||||
override def _sqlContext: SQLContext = ctx
|
||||
|
||||
test("Case insensitive attribute names") {
|
||||
withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
|
||||
|
@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
|
|||
test("Converting Hive to Parquet Table via saveAsParquetFile") {
|
||||
withTempPath { dir =>
|
||||
sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
|
||||
read.parquet(dir.getCanonicalPath).registerTempTable("p")
|
||||
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
|
||||
withTempTable("p") {
|
||||
checkAnswer(
|
||||
sql("SELECT * FROM src ORDER BY key"),
|
||||
|
@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
|
|||
withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
|
||||
withTempPath { file =>
|
||||
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
|
||||
read.parquet(file.getCanonicalPath).registerTempTable("p")
|
||||
ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
|
||||
withTempTable("p") {
|
||||
// let's do three overwrites for good measure
|
||||
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
|
||||
|
|
|
@ -22,7 +22,6 @@ import java.io.{IOException, File}
|
|||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.mapred.InvalidInputException
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.Logging
|
||||
|
@ -42,7 +41,8 @@ import org.apache.spark.util.Utils
|
|||
*/
|
||||
class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
|
||||
with Logging {
|
||||
override val sqlContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
var jsonFilePath: String = _
|
||||
|
||||
|
|
|
@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils
|
|||
import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
|
||||
|
||||
class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
|
||||
override val sqlContext: SQLContext = TestHive
|
||||
|
||||
import sqlContext.sql
|
||||
override val _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
private val df = sqlContext.range(10).coalesce(1)
|
||||
|
||||
|
|
|
@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext}
|
|||
class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest {
|
||||
import ParquetCompatibilityTest.makeNullable
|
||||
|
||||
override val sqlContext: SQLContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
/**
|
||||
* Set the staging directory (and hence path to ignore Parquet files under)
|
||||
|
|
|
@ -17,14 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql.hive
|
||||
|
||||
import org.apache.spark.sql.{Row, QueryTest}
|
||||
import org.apache.spark.sql.QueryTest
|
||||
|
||||
case class FunctionResult(f1: String, f2: String)
|
||||
|
||||
class UDFSuite extends QueryTest {
|
||||
|
||||
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
|
||||
import ctx.implicits._
|
||||
|
||||
test("UDF case insensitive") {
|
||||
ctx.udf.register("random0", () => { Math.random() })
|
||||
|
|
|
@ -17,17 +17,18 @@
|
|||
|
||||
package org.apache.spark.sql.hive.execution
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.execution.aggregate
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
|
||||
import org.apache.spark.sql._
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
|
||||
|
||||
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
|
||||
|
||||
override val sqlContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
protected val sqlContext = _sqlContext
|
||||
import sqlContext.implicits._
|
||||
|
||||
var originalUseAggregate2: Boolean = _
|
||||
|
|
|
@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils
|
|||
* A set of tests that validates support for Hive Explain command.
|
||||
*/
|
||||
class HiveExplainSuite extends QueryTest with SQLTestUtils {
|
||||
|
||||
def sqlContext: SQLContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
test("explain extended command") {
|
||||
checkExistence(sql(" explain select * from src where key=123 "), true,
|
||||
|
|
|
@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect
|
|||
* valid, but Hive currently cannot execute it.
|
||||
*/
|
||||
class SQLQuerySuite extends QueryTest with SQLTestUtils {
|
||||
override def sqlContext: SQLContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
test("UDTF") {
|
||||
sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
|
||||
|
|
|
@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType
|
|||
|
||||
class ScriptTransformationSuite extends SparkPlanTest {
|
||||
|
||||
override def sqlContext: SQLContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
private val noSerdeIOSchema = HiveScriptIOSchema(
|
||||
inputRowFormat = Seq.empty,
|
||||
|
|
|
@ -27,8 +27,8 @@ import org.apache.spark.sql._
|
|||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
|
||||
private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
|
||||
lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
|
||||
|
||||
protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
|
||||
protected val sqlContext = _sqlContext
|
||||
import sqlContext.implicits._
|
||||
import sqlContext.sparkContext
|
||||
|
||||
|
|
|
@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
|
|||
* A collection of tests for parquet data with various forms of partitioning.
|
||||
*/
|
||||
abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
|
||||
override def sqlContext: SQLContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
protected val sqlContext = _sqlContext
|
||||
|
||||
var partitionedTableDir: File = null
|
||||
var normalTableDir: File = null
|
||||
|
|
|
@ -18,14 +18,16 @@
|
|||
package org.apache.spark.sql.sources
|
||||
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.{SparkException, SparkFunSuite}
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.sql.hive.test.TestHive
|
||||
import org.apache.spark.sql.test.SQLTestUtils
|
||||
|
||||
|
||||
class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
|
||||
override val sqlContext = TestHive
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
private val sqlContext = _sqlContext
|
||||
|
||||
// When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
|
||||
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
|
||||
|
|
|
@ -34,9 +34,8 @@ import org.apache.spark.sql.types._
|
|||
|
||||
|
||||
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
|
||||
override lazy val sqlContext: SQLContext = TestHive
|
||||
|
||||
import sqlContext.sql
|
||||
override def _sqlContext: SQLContext = TestHive
|
||||
protected val sqlContext = _sqlContext
|
||||
import sqlContext.implicits._
|
||||
|
||||
val dataSourceName: String
|
||||
|
|
Loading…
Reference in a new issue