[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext

Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and removes a lot of code duplicated from `SparkPlanTest`.
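A minimal sketch of the new pattern, mirroring the `LimitNodeSuite` changes below and assuming the `conf` and `kvIntAttributes` helpers that `LocalNodeTest` now provides:

```scala
// Build a LocalNode tree over an in-memory array, run it, and compare the
// collected rows against a plain Scala array -- no SQLContext or DataFrame.
val inputData = (1 to 100).map { i => (i, i) }.toArray
val inputNode = new DummyNode(kvIntAttributes, inputData)
val limitNode = new LimitNode(conf, 10, inputNode)
val actualOutput = limitNode.collect().map { row => (row.getInt(0), row.getInt(1)) }
assert(actualOutput === inputData.take(10))
```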

This also fixes an additional issue, [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624), where the output of `TakeOrderedAndProjectNode` was not actually ordered.
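For context, a self-contained sketch of the underlying issue, using Scala's `mutable.PriorityQueue` in place of Spark's internal bounded queue: the node keeps the top rows in a heap, and a heap's iterator is not in sorted order, so the retained rows must be sorted before exposing an iterator (the `queue.toArray.sorted(ord).iterator` change below).

```scala
import scala.collection.mutable
import scala.util.Random

object TopKOrderingSketch extends App {
  val limit = 5
  val ord = Ordering.Int
  // Max-heap under `ord`: keep the `limit` smallest values by evicting the
  // current maximum whenever the queue grows past the bound.
  val queue = mutable.PriorityQueue.empty[Int](ord)
  Random.shuffle((1 to 100).toList).foreach { i =>
    queue.enqueue(i)
    if (queue.size > limit) queue.dequeue()
  }
  // Iterating the heap directly yields heap order, not sorted order -- the bug.
  println(s"heap iteration order: ${queue.iterator.toList}")
  // Sorting the drained contents restores the ordered-output contract -- the fix.
  println(s"sorted output:        ${queue.toArray.sorted(ord).toList}")
}
```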

Author: Andrew Or <andrew@databricks.com>

Closes #8764 from andrewor14/sql-local-tests-cleanup.
Andrew Or 2015-09-15 17:24:32 -07:00
parent 38700ea40c
commit 35a19f3357
17 changed files with 483 additions and 651 deletions

View file

@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.StructType
/**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {
protected val codegenEnabled: Boolean = conf.codegenEnabled
protected val unsafeEnabled: Boolean = conf.unsafeEnabled
lazy val schema: StructType = StructType.fromAttributes(output)
private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
def output: Seq[Attribute]
/**
* Called before open(). Prepare can be used to reserve memory needed. It must NOT consume
* any input data.

View file

@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution.local
import java.util.Random
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
/**
* Sample the dataset.
*
@@ -51,18 +50,15 @@ case class SampleNode(
override def open(): Unit = {
child.open()
val (sampler, _seed) = if (withReplacement) {
val random = new Random(seed)
val sampler =
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
// of DataFrame
random.nextLong())
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
} else {
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
}
sampler.setSeed(_seed)
sampler.setSeed(seed)
iterator = sampler.sample(child.asIterator)
}

View file

@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
}
// Close it eagerly since we don't need it.
child.close()
iterator = queue.iterator
iterator = queue.toArray.sorted(ord).iterator
}
override def next(): Boolean = {

View file

@@ -238,7 +238,7 @@ object SparkPlanTest {
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan.transformExpressions {
plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))

View file

@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
/**
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
*/
private[local] case class DummyNode(
output: Seq[Attribute],
relation: LocalRelation,
conf: SQLConf)
extends LocalNode(conf) {
import DummyNode._
private var index: Int = CLOSED
private val input: Seq[InternalRow] = relation.data
def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
this(output, LocalRelation.fromProduct(output, data), conf)
}
def isOpen: Boolean = index != CLOSED
override def children: Seq[LocalNode] = Seq.empty
override def open(): Unit = {
index = -1
}
override def next(): Boolean = {
index += 1
index < input.size
}
override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
input(index)
}
override def close(): Unit = {
index = CLOSED
}
}
private object DummyNode {
val CLOSED: Int = Int.MinValue
}

View file

@@ -17,35 +17,33 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.catalyst.dsl.expressions._
class ExpandNodeSuite extends LocalNodeTest {
import testImplicits._
test("expand") {
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
checkAnswer(
input,
node =>
ExpandNode(conf, Seq(
Seq(
input.col("key") + input.col("value"), input.col("key") - input.col("value")
).map(_.expr),
Seq(
input.col("key") * input.col("value"), input.col("key") / input.col("value")
).map(_.expr)
), node.output, node),
Seq(
(2, 0),
(1, 1),
(4, 0),
(4, 1),
(6, 0),
(9, 1),
(8, 0),
(16, 1),
(10, 0),
(25, 1)
).toDF().collect()
)
private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val inputNode = new DummyNode(kvIntAttributes, inputData)
val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
val resolvedNode = resolveExpressions(expandNode)
val expectedOutput = {
val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
firstHalf ++ secondHalf
}
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput.toSet === expectedOutput.toSet)
}
test("empty") {
testExpand()
}
test("basic") {
testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
}
}

View file

@@ -17,25 +17,29 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._
class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
test("basic") {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
node => FilterNode(conf, condition.expr, node),
testData.filter(condition).collect()
)
class FilterNodeSuite extends LocalNodeTest {
private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
val cond = 'k % 2 === 0
val inputNode = new DummyNode(kvIntAttributes, inputData)
val filterNode = new FilterNode(conf, cond, inputNode)
val resolvedNode = resolveExpressions(filterNode)
val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
val actualOutput = resolvedNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}
test("empty") {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
node => FilterNode(conf, condition.expr, node),
emptyTestData.filter(condition).collect()
)
testFilter()
}
test("basic") {
testFilter((1 to 100).map { i => (i, i) }.toArray)
}
}

View file

@@ -18,99 +18,80 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
class HashJoinNodeSuite extends LocalNodeTest {
import testImplicits._
def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
test(s"$suiteName: inner join with one match per row") {
withSQLConf(confPairs: _*) {
checkAnswer2(
upperCaseData,
lowerCaseData,
wrapForUnsafe(
(node1, node2) => HashJoinNode(
conf,
Seq(upperCaseData.col("N").expr),
Seq(lowerCaseData.col("n").expr),
joins.BuildLeft,
node1,
node2)
),
upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
)
}
}
test(s"$suiteName: inner join with multiple matches") {
withSQLConf(confPairs: _*) {
val x = testData2.where($"a" === 1).as("x")
val y = testData2.where($"a" === 1).as("y")
checkAnswer2(
x,
y,
wrapForUnsafe(
(node1, node2) => HashJoinNode(
conf,
Seq(x.col("a").expr),
Seq(y.col("a").expr),
joins.BuildLeft,
node1,
node2)
),
x.join(y).where($"x.a" === $"y.a").collect()
)
}
}
test(s"$suiteName: inner join, no matches") {
withSQLConf(confPairs: _*) {
val x = testData2.where($"a" === 1).as("x")
val y = testData2.where($"a" === 2).as("y")
checkAnswer2(
x,
y,
wrapForUnsafe(
(node1, node2) => HashJoinNode(
conf,
Seq(x.col("a").expr),
Seq(y.col("a").expr),
joins.BuildLeft,
node1,
node2)
),
Nil
)
}
}
test(s"$suiteName: big inner join, 4 matches per row") {
withSQLConf(confPairs: _*) {
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
val bigDataX = bigData.as("x")
val bigDataY = bigData.as("y")
checkAnswer2(
bigDataX,
bigDataY,
wrapForUnsafe(
(node1, node2) =>
HashJoinNode(
conf,
Seq(bigDataX.col("key").expr),
Seq(bigDataY.col("key").expr),
joins.BuildLeft,
node1,
node2)
),
bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
}
// Test all combinations of the two dimensions: with/out unsafe and build sides
private val maybeUnsafeAndCodegen = Seq(false, true)
private val buildSides = Seq(BuildLeft, BuildRight)
maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
buildSides.foreach { buildSide =>
testJoin(unsafeAndCodegen, buildSide)
}
}
/**
* Test inner hash join with varying degrees of matches.
*/
private def testJoin(
unsafeAndCodegen: Boolean,
buildSide: BuildSide): Unit = {
val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
val testNamePrefix = s"$simpleOrUnsafe / $buildSide"
val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
val conf = new SQLConf
conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
// Actual test body
def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = {
val rightInputMap = rightInput.toMap
val leftNode = new DummyNode(joinNameAttributes, leftInput)
val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
val makeNode = (node1: LocalNode, node2: LocalNode) => {
resolveExpressions(new HashJoinNode(
conf, Seq('id1), Seq('id2), buildSide, node1, node2))
}
val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
val expectedOutput = leftInput
.filter { case (k, _) => rightInputMap.contains(k) }
.map { case (k, v) => (k, v, k, rightInputMap(k)) }
val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname)
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
}
assert(actualOutput === expectedOutput)
}
test(s"$testNamePrefix: empty") {
runTest(Array.empty, Array.empty)
runTest(someData, Array.empty)
runTest(Array.empty, someData)
}
test(s"$testNamePrefix: no matches") {
val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray
runTest(someData, Array.empty)
runTest(Array.empty, someData)
runTest(someData, someIrrelevantData)
runTest(someIrrelevantData, someData)
}
test(s"$testNamePrefix: partial matches") {
val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
runTest(someData, someOtherData)
runTest(someOtherData, someData)
}
test(s"$testNamePrefix: full matches") {
val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray
runTest(someData, someSuperRelevantData)
runTest(someSuperRelevantData, someData)
}
}
joinSuite(
"general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
}

View file

@@ -17,19 +17,21 @@
package org.apache.spark.sql.execution.local
class IntersectNodeSuite extends LocalNodeTest {
import testImplicits._
test("basic") {
val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
checkAnswer2(
input1,
input2,
(node1, node2) => IntersectNode(conf, node1, node2),
input1.intersect(input2).collect()
)
val n = 100
val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray
val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray
val leftNode = new DummyNode(kvIntAttributes, leftData)
val rightNode = new DummyNode(kvIntAttributes, rightData)
val intersectNode = new IntersectNode(conf, leftNode, rightNode)
val expectedOutput = leftData.intersect(rightData)
val actualOutput = intersectNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}
}

View file

@@ -17,23 +17,25 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.test.SharedSQLContext
class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
class LimitNodeSuite extends LocalNodeTest {
test("basic") {
checkAnswer(
testData,
node => LimitNode(conf, 10, node),
testData.limit(10).collect()
)
private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = {
val inputNode = new DummyNode(kvIntAttributes, inputData)
val limitNode = new LimitNode(conf, limit, inputNode)
val expectedOutput = inputData.take(limit)
val actualOutput = limitNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}
test("empty") {
checkAnswer(
emptyTestData,
node => LimitNode(conf, 10, node),
emptyTestData.limit(10).collect()
)
testLimit()
}
test("basic") {
testLimit((1 to 100).map { i => (i, i) }.toArray, 20)
}
}

View file

@@ -17,28 +17,24 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType
class LocalNodeSuite extends SparkFunSuite {
private val data = (1 to 100).toArray
class LocalNodeSuite extends LocalNodeTest {
private val data = (1 to 100).map { i => (i, i) }.toArray
test("basic open, next, fetch, close") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
assert(!node.isOpen)
node.open()
assert(node.isOpen)
data.foreach { i =>
data.foreach { case (k, v) =>
assert(node.next())
// fetch should be idempotent
val fetched = node.fetch()
assert(node.fetch() === fetched)
assert(node.fetch() === fetched)
assert(node.fetch().numFields === 1)
assert(node.fetch().getInt(0) === i)
assert(node.fetch().numFields === 2)
assert(node.fetch().getInt(0) === k)
assert(node.fetch().getInt(1) === v)
}
assert(!node.next())
node.close()
@@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite {
}
test("asIterator") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
val iter = node.asIterator
node.open()
data.foreach { i =>
data.foreach { case (k, v) =>
// hasNext should be idempotent
assert(iter.hasNext)
assert(iter.hasNext)
val item = iter.next()
assert(item.numFields === 1)
assert(item.getInt(0) === i)
assert(item.numFields === 2)
assert(item.getInt(0) === k)
assert(item.getInt(1) === v)
}
intercept[NoSuchElementException] {
iter.next()
@@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite {
}
test("collect") {
val node = new DummyLocalNode(data)
val node = new DummyNode(kvIntAttributes, data)
node.open()
val collected = node.collect()
assert(collected.size === data.size)
assert(collected.forall(_.size === 1))
assert(collected.map(_.getInt(0)) === data)
assert(collected.forall(_.size === 2))
assert(collected.map { case row => (row.getInt(0), row.getInt(1)) } === data)
node.close()
}
}
/**
* A dummy [[LocalNode]] that just returns one row per integer in the input.
*/
private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
private var index = Int.MinValue
def this(input: Array[Int]) {
this(new SQLConf, input)
}
def isOpen: Boolean = {
index != Int.MinValue
}
override def output: Seq[Attribute] = {
Seq(AttributeReference("something", IntegerType)())
}
override def children: Seq[LocalNode] = Seq.empty
override def open(): Unit = {
index = -1
}
override def next(): Boolean = {
index += 1
index < input.size
}
override def fetch(): InternalRow = {
assert(index >= 0 && index < input.size)
val values = Array(input(index).asInstanceOf[Any])
new GenericInternalRow(values)
}
override def close(): Unit = {
index = Int.MinValue
}
}

View file

@@ -17,147 +17,54 @@
package org.apache.spark.sql.execution.local
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{IntegerType, StringType}
class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
def conf: SQLConf = sqlContext.conf
class LocalNodeTest extends SparkFunSuite {
protected val conf: SQLConf = new SQLConf
protected val kvIntAttributes = Seq(
AttributeReference("k", IntegerType)(),
AttributeReference("v", IntegerType)())
protected val joinNameAttributes = Seq(
AttributeReference("id1", IntegerType)(),
AttributeReference("name", StringType)())
protected val joinNicknameAttributes = Seq(
AttributeReference("id2", IntegerType)(),
AttributeReference("nickname", StringType)())
/**
* Wrap a function processing two [[LocalNode]]s such that:
* (1) all input rows are automatically converted to unsafe rows
* (2) all output rows are automatically converted back to safe rows
*/
protected def wrapForUnsafe(
f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
if (conf.unsafeEnabled) {
(left: LocalNode, right: LocalNode) => {
val _left = ConvertToUnsafeNode(conf, left)
val _right = ConvertToUnsafeNode(conf, right)
val r = f(_left, _right)
ConvertToSafeNode(conf, r)
}
} else {
f
(left: LocalNode, right: LocalNode) => {
val _left = ConvertToUnsafeNode(conf, left)
val _right = ConvertToUnsafeNode(conf, right)
val r = f(_left, _right)
ConvertToSafeNode(conf, r)
}
}
/**
* Runs the LocalNode and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
* the local physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
* Recursively resolve all expressions in a [[LocalNode]] using the node's attributes.
*/
protected def checkAnswer(
input: DataFrame,
nodeFunction: LocalNode => LocalNode,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
doCheckAnswer(
input :: Nil,
nodes => nodeFunction(nodes.head),
expectedAnswer,
sortAnswers)
}
/**
* Runs the LocalNode and makes sure the answer matches the expected result.
* @param left the left input data to be used.
* @param right the right input data to be used.
* @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
* the local physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
nodeFunction: (LocalNode, LocalNode) => LocalNode,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
doCheckAnswer(
left :: right :: Nil,
nodes => nodeFunction(nodes(0), nodes(1)),
expectedAnswer,
sortAnswers)
}
/**
* Runs the `LocalNode`s and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
* instantiate the local physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
protected def doCheckAnswer(
input: Seq[DataFrame],
nodeFunction: Seq[LocalNode] => LocalNode,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
LocalNodeTest.checkAnswer(
input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
protected def resolveExpressions(outputNode: LocalNode): LocalNode = {
outputNode transform {
case node: LocalNode =>
val inputMap = node.output.map { a => (a.name, a) }.toMap
node transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
}
protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
new SeqScanNode(
conf,
df.queryExecution.sparkPlan.output,
df.queryExecution.toRdd.map(_.copy()).collect())
}
}
/**
* Helper methods for writing tests of individual local physical operators.
*/
object LocalNodeTest {
/**
* Runs the `LocalNode`s and makes sure the answer matches the expected result.
* @param input the input data to be used.
* @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
* instantiate the local physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
* to being compared.
*/
def checkAnswer(
input: Seq[SeqScanNode],
nodeFunction: Seq[LocalNode] => LocalNode,
expectedAnswer: Seq[Row],
sortAnswers: Boolean): Option[String] = {
val outputNode = nodeFunction(input)
val outputResult: Seq[Row] = try {
outputNode.collect()
} catch {
case NonFatal(e) =>
val errorMessage =
s"""
| Exception thrown while executing local plan:
| $outputNode
| == Exception ==
| $e
| ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
""".stripMargin
return Some(errorMessage)
}
SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for local plan:
| $outputNode
| $errorMessage
""".stripMargin
}
}
}

View file

@@ -18,222 +18,128 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
class NestedLoopJoinNodeSuite extends LocalNodeTest {
import testImplicits._
private def joinSuite(
suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
test(s"$suiteName: left outer join") {
withSQLConf(confPairs: _*) {
checkAnswer2(
upperCaseData,
lowerCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
LeftOuter,
Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
),
upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect())
checkAnswer2(
upperCaseData,
lowerCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
LeftOuter,
Some(
(upperCaseData.col("N") === lowerCaseData.col("n") &&
lowerCaseData.col("n") > 1).expr))
),
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect())
checkAnswer2(
upperCaseData,
lowerCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
LeftOuter,
Some(
(upperCaseData.col("N") === lowerCaseData.col("n") &&
upperCaseData.col("N") > 1).expr))
),
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect())
checkAnswer2(
upperCaseData,
lowerCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
LeftOuter,
Some(
(upperCaseData.col("N") === lowerCaseData.col("n") &&
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
),
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect())
}
}
test(s"$suiteName: right outer join") {
withSQLConf(confPairs: _*) {
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
RightOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
RightOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
lowerCaseData.col("n") > 1).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
RightOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
upperCaseData.col("N") > 1).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
RightOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect())
}
}
test(s"$suiteName: full outer join") {
withSQLConf(confPairs: _*) {
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
FullOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
FullOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
lowerCaseData.col("n") > 1).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
FullOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
upperCaseData.col("N") > 1).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect())
checkAnswer2(
lowerCaseData,
upperCaseData,
wrapForUnsafe(
(node1, node2) => NestedLoopJoinNode(
conf,
node1,
node2,
buildSide,
FullOuter,
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
),
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect())
// Test all combinations of the three dimensions: with/out unsafe, build sides, and join types
private val maybeUnsafeAndCodegen = Seq(false, true)
private val buildSides = Seq(BuildLeft, BuildRight)
private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
buildSides.foreach { buildSide =>
joinTypes.foreach { joinType =>
testJoin(unsafeAndCodegen, buildSide, joinType)
}
}
}
joinSuite(
"general-build-left",
BuildLeft,
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
joinSuite(
"general-build-right",
BuildRight,
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
joinSuite(
"tungsten-build-left",
BuildLeft,
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
joinSuite(
"tungsten-build-right",
BuildRight,
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
/**
* Test outer nested loop joins with varying degrees of matches.
*/
private def testJoin(
unsafeAndCodegen: Boolean,
buildSide: BuildSide,
joinType: JoinType): Unit = {
val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType"
val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
val conf = new SQLConf
conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
// Actual test body
def runTest(
joinType: JoinType,
leftInput: Array[(Int, String)],
rightInput: Array[(Int, String)]): Unit = {
val leftNode = new DummyNode(joinNameAttributes, leftInput)
val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
val cond = 'id1 === 'id2
val makeNode = (node1: LocalNode, node2: LocalNode) => {
resolveExpressions(
new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond)))
}
val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
val actualOutput = hashJoinNode.collect().map { row =>
// (id, name, id, nickname)
(row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
}
assert(actualOutput.toSet === expectedOutput.toSet)
}
test(s"$testNamePrefix: empty") {
runTest(joinType, Array.empty, Array.empty)
}
test(s"$testNamePrefix: no matches") {
val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray
runTest(joinType, someData, Array.empty)
runTest(joinType, Array.empty, someData)
runTest(joinType, someData, someIrrelevantData)
runTest(joinType, someIrrelevantData, someData)
}
test(s"$testNamePrefix: partial matches") {
val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
runTest(joinType, someData, someOtherData)
runTest(joinType, someOtherData, someData)
}
test(s"$testNamePrefix: full matches") {
val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }
runTest(joinType, someData, someSuperRelevantData)
runTest(joinType, someSuperRelevantData, someData)
}
}
/**
* Helper method to generate the expected output of a test based on the join type.
*/
private def generateExpectedOutput(
leftInput: Array[(Int, String)],
rightInput: Array[(Int, String)],
joinType: JoinType): Array[(Int, String, Int, String)] = {
joinType match {
case LeftOuter =>
val rightInputMap = rightInput.toMap
leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
val rightValue = rightInputMap.getOrElse(k, null)
(k, v, rightKey, rightValue)
}
case RightOuter =>
val leftInputMap = leftInput.toMap
rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
val leftValue = leftInputMap.getOrElse(k, null)
(leftKey, leftValue, k, v)
}
case FullOuter =>
val leftInputMap = leftInput.toMap
val rightInputMap = rightInput.toMap
val leftOutput = leftInput.map { case (k, v) =>
val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
val rightValue = rightInputMap.getOrElse(k, null)
(k, v, rightKey, rightValue)
}
val rightOutput = rightInput.map { case (k, v) =>
val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
val leftValue = leftInputMap.getOrElse(k, null)
(leftKey, leftValue, k, v)
}
(leftOutput ++ rightOutput).distinct
case other =>
throw new IllegalArgumentException(s"Join type $other is not applicable")
}
}
}

View file

@@ -17,28 +17,33 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, StringType}
class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
test("basic") {
val output = testData.queryExecution.sparkPlan.output
val columns = Seq(output(1), output(0))
checkAnswer(
testData,
node => ProjectNode(conf, columns, node),
testData.select("value", "key").collect()
)
class ProjectNodeSuite extends LocalNodeTest {
private val pieAttributes = Seq(
AttributeReference("id", IntegerType)(),
AttributeReference("age", IntegerType)(),
AttributeReference("name", StringType)())
private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
val inputNode = new DummyNode(pieAttributes, inputData)
val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
val projectNode = new ProjectNode(conf, columns, inputNode)
val expectedOutput = inputData.map { case (id, age, name) => (id, name) }
val actualOutput = projectNode.collect().map { case row =>
(row.getInt(0), row.getString(1))
}
assert(actualOutput === expectedOutput)
}
test("empty") {
val output = emptyTestData.queryExecution.sparkPlan.output
val columns = Seq(output(1), output(0))
checkAnswer(
emptyTestData,
node => ProjectNode(conf, columns, node),
emptyTestData.select("value", "key").collect()
)
testProject()
}
test("basic") {
testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
}
}

View file

@@ -17,21 +17,32 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
class SampleNodeSuite extends LocalNodeTest {
import testImplicits._
private def testSample(withReplacement: Boolean): Unit = {
test(s"withReplacement: $withReplacement") {
val seed = 0L
val input = sqlContext.sparkContext.
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
toDF("key", "value")
checkAnswer(
input,
node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
input.sample(withReplacement, 0.3, seed).collect()
)
val seed = 0L
val lowerb = 0.0
val upperb = 0.3
val maybeOut = if (withReplacement) "" else "out"
test(s"with$maybeOut replacement") {
val inputData = (1 to 1000).map { i => (i, i) }.toArray
val inputNode = new DummyNode(kvIntAttributes, inputData)
val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode)
val sampler =
if (withReplacement) {
new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false)
} else {
new BernoulliCellSampler[(Int, Int)](lowerb, upperb)
}
sampler.setSeed(seed)
val expectedOutput = sampler.sample(inputData.iterator).toArray
val actualOutput = sampleNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}
}

View file

@@ -17,38 +17,34 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
import scala.util.Random
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SortOrder
class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
import testImplicits._
private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
sortOrder
}
private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
val testCaseName = if (desc) "desc" else "asc"
test(testCaseName) {
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
val sortColumn = if (desc) input.col("key").desc else input.col("key")
checkAnswer(
input,
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
input.sort(sortColumn).limit(5).collect()
)
private def testTakeOrderedAndProject(desc: Boolean): Unit = {
val limit = 10
val ascOrDesc = if (desc) "desc" else "asc"
test(ascOrDesc) {
val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray
val inputNode = new DummyNode(kvIntAttributes, inputData)
val firstColumn = inputNode.output(0)
val sortDirection = if (desc) Descending else Ascending
val sortOrder = SortOrder(firstColumn, sortDirection)
val takeOrderAndProjectNode = new TakeOrderedAndProjectNode(
conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode)
val expectedOutput = inputData
.map { case (k, _) => k }
.sortBy { k => k * (if (desc) -1 else 1) }
.take(limit)
val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) }
assert(actualOutput === expectedOutput)
}
}
testTakeOrderedAndProjectNode(desc = false)
testTakeOrderedAndProjectNode(desc = true)
testTakeOrderedAndProject(desc = false)
testTakeOrderedAndProject(desc = true)
}

View file

@@ -17,36 +17,39 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.test.SharedSQLContext
class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
class UnionNodeSuite extends LocalNodeTest {
test("basic") {
checkAnswer2(
testData,
testData,
(node1, node2) => UnionNode(conf, Seq(node1, node2)),
testData.unionAll(testData).collect()
)
private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = {
val inputNodes = inputData.map { data =>
new DummyNode(kvIntAttributes, data)
}
val unionNode = new UnionNode(conf, inputNodes)
val expectedOutput = inputData.flatten
val actualOutput = unionNode.collect().map { case row =>
(row.getInt(0), row.getInt(1))
}
assert(actualOutput === expectedOutput)
}
test("empty") {
checkAnswer2(
emptyTestData,
emptyTestData,
(node1, node2) => UnionNode(conf, Seq(node1, node2)),
emptyTestData.unionAll(emptyTestData).collect()
)
testUnion(Seq(Array.empty))
testUnion(Seq(Array.empty, Array.empty))
}
test("complicated union") {
val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
emptyTestData, emptyTestData, testData, emptyTestData)
doCheckAnswer(
dfs,
nodes => UnionNode(conf, nodes),
dfs.reduce(_.unionAll(_)).collect()
)
test("self") {
val data = (1 to 100).map { i => (i, i) }.toArray
testUnion(Seq(data))
testUnion(Seq(data, data))
testUnion(Seq(data, data, data))
}
test("basic") {
val zero = Array.empty[(Int, Int)]
val one = (1 to 100).map { i => (i, i) }.toArray
val two = (50 to 150).map { i => (i, i) }.toArray
val three = (800 to 900).map { i => (i, i) }.toArray
testUnion(Seq(zero, one, two, three))
}
}