[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext
Instead of relying on `DataFrame`s to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and removes a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue, [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624), where the output of `TakeOrderedAndProjectNode` is not actually ordered.
Author: Andrew Or <andrew@databricks.com>
Closes #8764 from andrewor14/sql-local-tests-cleanup.
This commit is contained in:
parent
38700ea40c
commit
35a19f3357
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
|
|||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNode
|
||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
/**
|
||||
|
@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
|
|||
* Before consuming the iterator, open function must be called.
|
||||
* After consuming the iterator, close function must be called.
|
||||
*/
|
||||
abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
|
||||
abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {
|
||||
|
||||
protected val codegenEnabled: Boolean = conf.codegenEnabled
|
||||
|
||||
protected val unsafeEnabled: Boolean = conf.unsafeEnabled
|
||||
|
||||
lazy val schema: StructType = StructType.fromAttributes(output)
|
||||
|
||||
private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
|
||||
|
||||
def output: Seq[Attribute]
|
||||
|
||||
/**
|
||||
* Called before open(). Prepare can be used to reserve memory needed. It must NOT consume
|
||||
* any input data.
|
||||
|
|
|
@ -17,13 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import java.util.Random
|
||||
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
|
||||
|
||||
|
||||
/**
|
||||
* Sample the dataset.
|
||||
*
|
||||
|
@ -51,18 +50,15 @@ case class SampleNode(
|
|||
|
||||
override def open(): Unit = {
|
||||
child.open()
|
||||
val (sampler, _seed) = if (withReplacement) {
|
||||
val random = new Random(seed)
|
||||
val sampler =
|
||||
if (withReplacement) {
|
||||
// Disable gap sampling since the gap sampling method buffers two rows internally,
|
||||
// requiring us to copy the row, which is more expensive than the random number generator.
|
||||
(new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
|
||||
// Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
|
||||
// of DataFrame
|
||||
random.nextLong())
|
||||
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
|
||||
} else {
|
||||
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
|
||||
new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
|
||||
}
|
||||
sampler.setSeed(_seed)
|
||||
sampler.setSeed(seed)
|
||||
iterator = sampler.sample(child.asIterator)
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
|
|||
}
|
||||
// Close it eagerly since we don't need it.
|
||||
child.close()
|
||||
iterator = queue.iterator
|
||||
iterator = queue.toArray.sorted(ord).iterator
|
||||
}
|
||||
|
||||
override def next(): Boolean = {
|
||||
|
|
|
@ -238,7 +238,7 @@ object SparkPlanTest {
|
|||
outputPlan transform {
|
||||
case plan: SparkPlan =>
|
||||
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
|
||||
plan.transformExpressions {
|
||||
plan transformExpressions {
|
||||
case UnresolvedAttribute(Seq(u)) =>
|
||||
inputMap.getOrElse(u,
|
||||
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
|
||||
|
||||
/**
|
||||
* A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
|
||||
*/
|
||||
private[local] case class DummyNode(
    output: Seq[Attribute],
    relation: LocalRelation,
    conf: SQLConf)
  extends LocalNode(conf) {

  import DummyNode.CLOSED

  // Rows served by this node, read once from the backing relation.
  private val rows: Seq[InternalRow] = relation.data

  // Position of the row handed out by fetch(); holds the CLOSED sentinel until open() runs.
  private var cursor: Int = CLOSED

  /** Builds the backing [[LocalRelation]] from plain products via `LocalRelation.fromProduct`. */
  def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) =
    this(output, LocalRelation.fromProduct(output, data), conf)

  /** True between a call to open() and the matching close(). */
  def isOpen: Boolean = cursor != CLOSED

  /** This node is a leaf: it has no child nodes. */
  override def children: Seq[LocalNode] = Nil

  override def open(): Unit = {
    // Start one step before the first row so that the first next() lands on position 0.
    cursor = -1
  }

  override def next(): Boolean = {
    // Advance unconditionally, then report whether the new position holds a row.
    cursor += 1
    cursor < rows.length
  }

  override def fetch(): InternalRow = {
    // Only valid after a next() that returned true.
    assert(0 <= cursor && cursor < rows.length)
    rows(cursor)
  }

  override def close(): Unit = {
    cursor = CLOSED
  }
}

private object DummyNode {
  /** Sentinel cursor value marking a node that is not currently open. */
  val CLOSED: Int = Int.MinValue
}
|
|
@ -17,35 +17,33 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
|
||||
|
||||
class ExpandNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
test("expand") {
|
||||
val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
|
||||
checkAnswer(
|
||||
input,
|
||||
node =>
|
||||
ExpandNode(conf, Seq(
|
||||
Seq(
|
||||
input.col("key") + input.col("value"), input.col("key") - input.col("value")
|
||||
).map(_.expr),
|
||||
Seq(
|
||||
input.col("key") * input.col("value"), input.col("key") / input.col("value")
|
||||
).map(_.expr)
|
||||
), node.output, node),
|
||||
Seq(
|
||||
(2, 0),
|
||||
(1, 1),
|
||||
(4, 0),
|
||||
(4, 1),
|
||||
(6, 0),
|
||||
(9, 1),
|
||||
(8, 0),
|
||||
(16, 1),
|
||||
(10, 0),
|
||||
(25, 1)
|
||||
).toDF().collect()
|
||||
)
|
||||
/**
 * Runs an [[ExpandNode]] with two projection lists, (k + v, k - v) and (k * v, k / v), over
 * `inputData` and checks that the set of produced rows equals the set computed directly on
 * the input tuples.
 *
 * The expected values use integer division by `v`, so callers must not pass rows with v == 0.
 */
private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
  val source = new DummyNode(kvIntAttributes, inputData)
  val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
  val expand = resolveExpressions(new ExpandNode(conf, projections, source.output, source))

  // One expected row per input row and per projection list.
  val sumsAndDiffs = inputData.map { case (k, v) => (k + v, k - v) }
  val productsAndQuotients = inputData.map { case (k, v) => (k * v, k / v) }
  val expectedOutput = sumsAndDiffs ++ productsAndQuotients

  val actualOutput = expand.collect().map { row => (row.getInt(0), row.getInt(1)) }

  // Compared as sets: row order is not part of what this test checks.
  assert(actualOutput.toSet === expectedOutput.toSet)
}
|
||||
|
||||
test("empty") {
|
||||
testExpand()
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,25 +17,29 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
|
||||
class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
|
||||
|
||||
test("basic") {
|
||||
val condition = (testData.col("key") % 2) === 0
|
||||
checkAnswer(
|
||||
testData,
|
||||
node => FilterNode(conf, condition.expr, node),
|
||||
testData.filter(condition).collect()
|
||||
)
|
||||
class FilterNodeSuite extends LocalNodeTest {
|
||||
|
||||
/**
 * Runs a [[FilterNode]] that keeps rows whose `k` column is even and checks that the output,
 * in order, equals the same filter applied directly to `inputData`.
 */
private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
  val cond = 'k % 2 === 0
  val source = new DummyNode(kvIntAttributes, inputData)
  val filter = resolveExpressions(new FilterNode(conf, cond, source))

  val expectedOutput = inputData.filter(_._1 % 2 == 0)
  val actualOutput = filter.collect().map { row => (row.getInt(0), row.getInt(1)) }

  assert(actualOutput === expectedOutput)
}
|
||||
|
||||
test("empty") {
|
||||
val condition = (emptyTestData.col("key") % 2) === 0
|
||||
checkAnswer(
|
||||
emptyTestData,
|
||||
node => FilterNode(conf, condition.expr, node),
|
||||
emptyTestData.filter(condition).collect()
|
||||
)
|
||||
testFilter()
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
testFilter((1 to 100).map { i => (i, i) }.toArray)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -18,99 +18,80 @@
|
|||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.execution.joins
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
|
||||
|
||||
|
||||
class HashJoinNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
|
||||
test(s"$suiteName: inner join with one match per row") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
checkAnswer2(
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => HashJoinNode(
|
||||
conf,
|
||||
Seq(upperCaseData.col("N").expr),
|
||||
Seq(lowerCaseData.col("n").expr),
|
||||
joins.BuildLeft,
|
||||
node1,
|
||||
node2)
|
||||
),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$suiteName: inner join with multiple matches") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
val x = testData2.where($"a" === 1).as("x")
|
||||
val y = testData2.where($"a" === 1).as("y")
|
||||
checkAnswer2(
|
||||
x,
|
||||
y,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => HashJoinNode(
|
||||
conf,
|
||||
Seq(x.col("a").expr),
|
||||
Seq(y.col("a").expr),
|
||||
joins.BuildLeft,
|
||||
node1,
|
||||
node2)
|
||||
),
|
||||
x.join(y).where($"x.a" === $"y.a").collect()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$suiteName: inner join, no matches") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
val x = testData2.where($"a" === 1).as("x")
|
||||
val y = testData2.where($"a" === 2).as("y")
|
||||
checkAnswer2(
|
||||
x,
|
||||
y,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => HashJoinNode(
|
||||
conf,
|
||||
Seq(x.col("a").expr),
|
||||
Seq(y.col("a").expr),
|
||||
joins.BuildLeft,
|
||||
node1,
|
||||
node2)
|
||||
),
|
||||
Nil
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$suiteName: big inner join, 4 matches per row") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
|
||||
val bigDataX = bigData.as("x")
|
||||
val bigDataY = bigData.as("y")
|
||||
|
||||
checkAnswer2(
|
||||
bigDataX,
|
||||
bigDataY,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) =>
|
||||
HashJoinNode(
|
||||
conf,
|
||||
Seq(bigDataX.col("key").expr),
|
||||
Seq(bigDataY.col("key").expr),
|
||||
joins.BuildLeft,
|
||||
node1,
|
||||
node2)
|
||||
),
|
||||
bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
|
||||
}
|
||||
// Test all combinations of the two dimensions: with/out unsafe and build sides
|
||||
private val maybeUnsafeAndCodegen = Seq(false, true)
|
||||
private val buildSides = Seq(BuildLeft, BuildRight)
|
||||
maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
|
||||
buildSides.foreach { buildSide =>
|
||||
testJoin(unsafeAndCodegen, buildSide)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test inner hash join with varying degrees of matches.
|
||||
*/
|
||||
private def testJoin(
    unsafeAndCodegen: Boolean,
    buildSide: BuildSide): Unit = {
  val simpleOrUnsafe = if (unsafeAndCodegen) "unsafe" else "simple"
  val testNamePrefix = s"$simpleOrUnsafe / $buildSide"
  val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray

  // A dedicated conf (shadowing the suite-level one) so each combination toggles
  // unsafe and codegen together without affecting the other registered tests.
  val conf = new SQLConf
  conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
  conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)

  // Actual test body
  def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = {
    // The expected rows are derived from a key -> nickname map of the right side.
    val rightInputMap = rightInput.toMap
    val leftNode = new DummyNode(joinNameAttributes, leftInput)
    val rightNode = new DummyNode(joinNicknameAttributes, rightInput)

    def makeNode(node1: LocalNode, node2: LocalNode): LocalNode =
      resolveExpressions(new HashJoinNode(
        conf, Seq('id1), Seq('id2), buildSide, node1, node2))

    // In unsafe mode the join is sandwiched between the unsafe/safe conversion nodes.
    val buildJoin: (LocalNode, LocalNode) => LocalNode =
      if (unsafeAndCodegen) wrapForUnsafe(makeNode _) else makeNode _
    val hashJoinNode = buildJoin(leftNode, rightNode)

    val expectedOutput = leftInput.collect {
      case (k, v) if rightInputMap.contains(k) => (k, v, k, rightInputMap(k))
    }
    val actualOutput = hashJoinNode.collect().map { row =>
      // (id, name, id, nickname)
      (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
    }
    assert(actualOutput === expectedOutput)
  }

  test(s"$testNamePrefix: empty") {
    runTest(Array.empty, Array.empty)
    runTest(someData, Array.empty)
    runTest(Array.empty, someData)
  }

  test(s"$testNamePrefix: no matches") {
    // Keys start at 10000, so none of them overlaps with someData's 1..100.
    val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray
    runTest(someData, Array.empty)
    runTest(Array.empty, someData)
    runTest(someData, someIrrelevantData)
    runTest(someIrrelevantData, someData)
  }

  test(s"$testNamePrefix: partial matches") {
    // Keys 50..100 overlap with someData; 101..150 do not.
    val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
    runTest(someData, someOtherData)
    runTest(someOtherData, someData)
  }

  test(s"$testNamePrefix: full matches") {
    // Same keys as someData with different values; mapping an Array already yields an Array.
    val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }
    runTest(someData, someSuperRelevantData)
    runTest(someSuperRelevantData, someData)
  }
}
|
||||
|
||||
joinSuite(
|
||||
"general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
|
||||
joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
|
||||
}
|
||||
|
|
|
@ -17,19 +17,21 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
|
||||
class IntersectNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
test("basic") {
|
||||
val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
|
||||
val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
|
||||
|
||||
checkAnswer2(
|
||||
input1,
|
||||
input2,
|
||||
(node1, node2) => IntersectNode(conf, node1, node2),
|
||||
input1.intersect(input2).collect()
|
||||
)
|
||||
val n = 100
|
||||
val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray
|
||||
val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray
|
||||
val leftNode = new DummyNode(kvIntAttributes, leftData)
|
||||
val rightNode = new DummyNode(kvIntAttributes, rightData)
|
||||
val intersectNode = new IntersectNode(conf, leftNode, rightNode)
|
||||
val expectedOutput = leftData.intersect(rightData)
|
||||
val actualOutput = intersectNode.collect().map { case row =>
|
||||
(row.getInt(0), row.getInt(1))
|
||||
}
|
||||
assert(actualOutput === expectedOutput)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,23 +17,25 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
|
||||
class LimitNodeSuite extends LocalNodeTest {
|
||||
|
||||
test("basic") {
|
||||
checkAnswer(
|
||||
testData,
|
||||
node => LimitNode(conf, 10, node),
|
||||
testData.limit(10).collect()
|
||||
)
|
||||
/**
 * Runs a [[LimitNode]] with the given `limit` over `inputData` and checks that the output,
 * in order, equals the first `limit` input rows.
 */
private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = {
  val source = new DummyNode(kvIntAttributes, inputData)
  val limited = new LimitNode(conf, limit, source)

  val expectedOutput = inputData.take(limit)
  val actualOutput = limited.collect().map { row => (row.getInt(0), row.getInt(1)) }

  assert(actualOutput === expectedOutput)
}
|
||||
|
||||
test("empty") {
|
||||
checkAnswer(
|
||||
emptyTestData,
|
||||
node => LimitNode(conf, 10, node),
|
||||
emptyTestData.limit(10).collect()
|
||||
)
|
||||
testLimit()
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
testLimit((1 to 100).map { i => (i, i) }.toArray, 20)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,28 +17,24 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.types.IntegerType
|
||||
|
||||
class LocalNodeSuite extends SparkFunSuite {
|
||||
private val data = (1 to 100).toArray
|
||||
class LocalNodeSuite extends LocalNodeTest {
|
||||
private val data = (1 to 100).map { i => (i, i) }.toArray
|
||||
|
||||
test("basic open, next, fetch, close") {
|
||||
val node = new DummyLocalNode(data)
|
||||
val node = new DummyNode(kvIntAttributes, data)
|
||||
assert(!node.isOpen)
|
||||
node.open()
|
||||
assert(node.isOpen)
|
||||
data.foreach { i =>
|
||||
data.foreach { case (k, v) =>
|
||||
assert(node.next())
|
||||
// fetch should be idempotent
|
||||
val fetched = node.fetch()
|
||||
assert(node.fetch() === fetched)
|
||||
assert(node.fetch() === fetched)
|
||||
assert(node.fetch().numFields === 1)
|
||||
assert(node.fetch().getInt(0) === i)
|
||||
assert(node.fetch().numFields === 2)
|
||||
assert(node.fetch().getInt(0) === k)
|
||||
assert(node.fetch().getInt(1) === v)
|
||||
}
|
||||
assert(!node.next())
|
||||
node.close()
|
||||
|
@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("asIterator") {
|
||||
val node = new DummyLocalNode(data)
|
||||
val node = new DummyNode(kvIntAttributes, data)
|
||||
val iter = node.asIterator
|
||||
node.open()
|
||||
data.foreach { i =>
|
||||
data.foreach { case (k, v) =>
|
||||
// hasNext should be idempotent
|
||||
assert(iter.hasNext)
|
||||
assert(iter.hasNext)
|
||||
val item = iter.next()
|
||||
assert(item.numFields === 1)
|
||||
assert(item.getInt(0) === i)
|
||||
assert(item.numFields === 2)
|
||||
assert(item.getInt(0) === k)
|
||||
assert(item.getInt(1) === v)
|
||||
}
|
||||
intercept[NoSuchElementException] {
|
||||
iter.next()
|
||||
|
@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("collect") {
  val node = new DummyNode(kvIntAttributes, data)
  node.open()
  val collected = node.collect()
  assert(collected.size === data.size)
  assert(collected.forall(_.size === 2))
  // Read BOTH columns. Reading column 0 twice would still pass because the fixture rows
  // are (i, i), but it would never verify the contents of the second column.
  assert(collected.map { row => (row.getInt(0), row.getInt(1)) } === data)
  node.close()
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* A dummy [[LocalNode]] that just returns one row per integer in the input.
|
||||
*/
|
||||
private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
|
||||
private var index = Int.MinValue
|
||||
|
||||
def this(input: Array[Int]) {
|
||||
this(new SQLConf, input)
|
||||
}
|
||||
|
||||
def isOpen: Boolean = {
|
||||
index != Int.MinValue
|
||||
}
|
||||
|
||||
override def output: Seq[Attribute] = {
|
||||
Seq(AttributeReference("something", IntegerType)())
|
||||
}
|
||||
|
||||
override def children: Seq[LocalNode] = Seq.empty
|
||||
|
||||
override def open(): Unit = {
|
||||
index = -1
|
||||
}
|
||||
|
||||
override def next(): Boolean = {
|
||||
index += 1
|
||||
index < input.size
|
||||
}
|
||||
|
||||
override def fetch(): InternalRow = {
|
||||
assert(index >= 0 && index < input.size)
|
||||
val values = Array(input(index).asInstanceOf[Any])
|
||||
new GenericInternalRow(values)
|
||||
}
|
||||
|
||||
override def close(): Unit = {
|
||||
index = Int.MinValue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,147 +17,54 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
|
||||
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.AttributeReference
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType}
|
||||
|
||||
class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
|
||||
|
||||
def conf: SQLConf = sqlContext.conf
|
||||
class LocalNodeTest extends SparkFunSuite {
|
||||
|
||||
protected val conf: SQLConf = new SQLConf
|
||||
protected val kvIntAttributes = Seq(
|
||||
AttributeReference("k", IntegerType)(),
|
||||
AttributeReference("v", IntegerType)())
|
||||
protected val joinNameAttributes = Seq(
|
||||
AttributeReference("id1", IntegerType)(),
|
||||
AttributeReference("name", StringType)())
|
||||
protected val joinNicknameAttributes = Seq(
|
||||
AttributeReference("id2", IntegerType)(),
|
||||
AttributeReference("nickname", StringType)())
|
||||
|
||||
/**
|
||||
* Wrap a function processing two [[LocalNode]]s such that:
|
||||
* (1) all input rows are automatically converted to unsafe rows
|
||||
* (2) all output rows are automatically converted back to safe rows
|
||||
*/
|
||||
protected def wrapForUnsafe(
|
||||
f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
|
||||
if (conf.unsafeEnabled) {
|
||||
(left: LocalNode, right: LocalNode) => {
|
||||
val _left = ConvertToUnsafeNode(conf, left)
|
||||
val _right = ConvertToUnsafeNode(conf, right)
|
||||
val r = f(_left, _right)
|
||||
ConvertToSafeNode(conf, r)
|
||||
}
|
||||
} else {
|
||||
f
|
||||
(left: LocalNode, right: LocalNode) => {
|
||||
val _left = ConvertToUnsafeNode(conf, left)
|
||||
val _right = ConvertToUnsafeNode(conf, right)
|
||||
val r = f(_left, _right)
|
||||
ConvertToSafeNode(conf, r)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the LocalNode and makes sure the answer matches the expected result.
|
||||
* @param input the input data to be used.
|
||||
* @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
|
||||
* the local physical operator that's being tested.
|
||||
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
|
||||
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
|
||||
* to being compared.
|
||||
* Recursively resolve all expressions in a [[LocalNode]] using the node's attributes.
|
||||
*/
|
||||
protected def checkAnswer(
|
||||
input: DataFrame,
|
||||
nodeFunction: LocalNode => LocalNode,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
doCheckAnswer(
|
||||
input :: Nil,
|
||||
nodes => nodeFunction(nodes.head),
|
||||
expectedAnswer,
|
||||
sortAnswers)
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the LocalNode and makes sure the answer matches the expected result.
|
||||
* @param left the left input data to be used.
|
||||
* @param right the right input data to be used.
|
||||
* @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
|
||||
* the local physical operator that's being tested.
|
||||
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
|
||||
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
|
||||
* to being compared.
|
||||
*/
|
||||
protected def checkAnswer2(
|
||||
left: DataFrame,
|
||||
right: DataFrame,
|
||||
nodeFunction: (LocalNode, LocalNode) => LocalNode,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
doCheckAnswer(
|
||||
left :: right :: Nil,
|
||||
nodes => nodeFunction(nodes(0), nodes(1)),
|
||||
expectedAnswer,
|
||||
sortAnswers)
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the `LocalNode`s and makes sure the answer matches the expected result.
|
||||
* @param input the input data to be used.
|
||||
* @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
|
||||
* instantiate the local physical operator that's being tested.
|
||||
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
|
||||
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
|
||||
* to being compared.
|
||||
*/
|
||||
protected def doCheckAnswer(
|
||||
input: Seq[DataFrame],
|
||||
nodeFunction: Seq[LocalNode] => LocalNode,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean = true): Unit = {
|
||||
LocalNodeTest.checkAnswer(
|
||||
input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
|
||||
case Some(errorMessage) => fail(errorMessage)
|
||||
case None =>
|
||||
protected def resolveExpressions(outputNode: LocalNode): LocalNode = {
  outputNode.transform {
    case node: LocalNode =>
      // Attributes are looked up by name among the node's own output attributes.
      val inputMap = node.output.map(a => a.name -> a).toMap
      node.transformExpressions {
        // Only single-part unresolved names are rewritten; an unknown name aborts the test.
        case UnresolvedAttribute(Seq(u)) =>
          inputMap.getOrElse(u,
            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
      }
  }
}
|
||||
|
||||
protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
|
||||
new SeqScanNode(
|
||||
conf,
|
||||
df.queryExecution.sparkPlan.output,
|
||||
df.queryExecution.toRdd.map(_.copy()).collect())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper methods for writing tests of individual local physical operators.
|
||||
*/
|
||||
object LocalNodeTest {
|
||||
|
||||
/**
|
||||
* Runs the `LocalNode`s and makes sure the answer matches the expected result.
|
||||
* @param input the input data to be used.
|
||||
* @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
|
||||
* instantiate the local physical operator that's being tested.
|
||||
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
|
||||
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
|
||||
* to being compared.
|
||||
*/
|
||||
def checkAnswer(
|
||||
input: Seq[SeqScanNode],
|
||||
nodeFunction: Seq[LocalNode] => LocalNode,
|
||||
expectedAnswer: Seq[Row],
|
||||
sortAnswers: Boolean): Option[String] = {
|
||||
|
||||
val outputNode = nodeFunction(input)
|
||||
|
||||
val outputResult: Seq[Row] = try {
|
||||
outputNode.collect()
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
val errorMessage =
|
||||
s"""
|
||||
| Exception thrown while executing local plan:
|
||||
| $outputNode
|
||||
| == Exception ==
|
||||
| $e
|
||||
| ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
|
||||
""".stripMargin
|
||||
return Some(errorMessage)
|
||||
}
|
||||
|
||||
SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
|
||||
s"""
|
||||
| Results do not match for local plan:
|
||||
| $outputNode
|
||||
| $errorMessage
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,222 +18,128 @@
|
|||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.SQLConf
|
||||
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
|
||||
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
|
||||
|
||||
|
||||
class NestedLoopJoinNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
private def joinSuite(
|
||||
suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
|
||||
test(s"$suiteName: left outer join") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
checkAnswer2(
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
LeftOuter,
|
||||
Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
|
||||
),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect())
|
||||
|
||||
checkAnswer2(
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
LeftOuter,
|
||||
Some(
|
||||
(upperCaseData.col("N") === lowerCaseData.col("n") &&
|
||||
lowerCaseData.col("n") > 1).expr))
|
||||
),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect())
|
||||
|
||||
checkAnswer2(
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
LeftOuter,
|
||||
Some(
|
||||
(upperCaseData.col("N") === lowerCaseData.col("n") &&
|
||||
upperCaseData.col("N") > 1).expr))
|
||||
),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect())
|
||||
|
||||
checkAnswer2(
|
||||
upperCaseData,
|
||||
lowerCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
LeftOuter,
|
||||
Some(
|
||||
(upperCaseData.col("N") === lowerCaseData.col("n") &&
|
||||
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
|
||||
),
|
||||
upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect())
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$suiteName: right outer join") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
RightOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
RightOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
lowerCaseData.col("n") > 1).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
RightOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
upperCaseData.col("N") > 1).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
RightOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect())
|
||||
}
|
||||
}
|
||||
|
||||
test(s"$suiteName: full outer join") {
|
||||
withSQLConf(confPairs: _*) {
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
FullOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
FullOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
lowerCaseData.col("n") > 1).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
FullOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
upperCaseData.col("N") > 1).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect())
|
||||
|
||||
checkAnswer2(
|
||||
lowerCaseData,
|
||||
upperCaseData,
|
||||
wrapForUnsafe(
|
||||
(node1, node2) => NestedLoopJoinNode(
|
||||
conf,
|
||||
node1,
|
||||
node2,
|
||||
buildSide,
|
||||
FullOuter,
|
||||
Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
|
||||
lowerCaseData.col("l") > upperCaseData.col("L")).expr))
|
||||
),
|
||||
lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect())
|
||||
// Register the join tests for every point in the cross product of the three dimensions:
// unsafe/codegen off or on, which side is the build side, and the outer join type.
private val maybeUnsafeAndCodegen = Seq(false, true)
private val buildSides = Seq(BuildLeft, BuildRight)
private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
for {
  unsafeAndCodegen <- maybeUnsafeAndCodegen
  buildSide <- buildSides
  joinType <- joinTypes
} testJoin(unsafeAndCodegen, buildSide, joinType)
|
||||
|
||||
joinSuite(
|
||||
"general-build-left",
|
||||
BuildLeft,
|
||||
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
|
||||
joinSuite(
|
||||
"general-build-right",
|
||||
BuildRight,
|
||||
SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
|
||||
joinSuite(
|
||||
"tungsten-build-left",
|
||||
BuildLeft,
|
||||
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
|
||||
joinSuite(
|
||||
"tungsten-build-right",
|
||||
BuildRight,
|
||||
SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
|
||||
/**
 * Registers the outer nested loop join tests for one combination of settings, covering inputs
 * whose keys do not overlap at all, overlap partially, and overlap completely.
 *
 * @param unsafeAndCodegen value written to both `SQLConf.UNSAFE_ENABLED` and
 *                         `SQLConf.CODEGEN_ENABLED`; when true the join node is additionally
 *                         built through `wrapForUnsafe`.
 * @param buildSide build side handed to the `NestedLoopJoinNode`.
 * @param joinType join type handed to the node and used to compute the expected rows.
 */
private def testJoin(
    unsafeAndCodegen: Boolean,
    buildSide: BuildSide,
    joinType: JoinType): Unit = {
  val rowFormat = if (unsafeAndCodegen) "unsafe" else "simple"
  val namePrefix = s"$rowFormat / $buildSide / $joinType"

  // A dedicated conf per combination so the flags of one test group never leak into another.
  val conf = new SQLConf
  conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
  conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)

  // Keys 1..100; the other data sets below are defined relative to this range.
  val baseRows = (1 to 100).map { i => (i, "burger" + i) }.toArray

  // Joins the two inputs on their id columns and compares the collected rows, as a set, with
  // the rows computed by generateExpectedOutput.
  def verifyJoin(leftRows: Array[(Int, String)], rightRows: Array[(Int, String)]): Unit = {
    val leftNode = new DummyNode(joinNameAttributes, leftRows)
    val rightNode = new DummyNode(joinNicknameAttributes, rightRows)
    val joinCondition = 'id1 === 'id2
    val buildJoin = (node1: LocalNode, node2: LocalNode) => {
      resolveExpressions(
        new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(joinCondition)))
    }
    val buildMaybeUnsafeJoin = if (unsafeAndCodegen) wrapForUnsafe(buildJoin) else buildJoin
    val joinNode = buildMaybeUnsafeJoin(leftNode, rightNode)
    val expectedRows = generateExpectedOutput(leftRows, rightRows, joinType)
    // Column layout of a joined row: (id, name, id, nickname)
    val actualRows = joinNode.collect().map { row =>
      (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
    }
    // Compared as sets: the order in which the join emits rows is not checked.
    assert(actualRows.toSet === expectedRows.toSet)
  }

  test(s"$namePrefix: empty") {
    verifyJoin(Array.empty, Array.empty)
  }

  test(s"$namePrefix: no matches") {
    // Keys 10000..10100 share nothing with baseRows.
    val disjointRows = (10000 to 10100).map { i => (i, "piper" + i) }.toArray
    verifyJoin(baseRows, Array.empty)
    verifyJoin(Array.empty, baseRows)
    verifyJoin(baseRows, disjointRows)
    verifyJoin(disjointRows, baseRows)
  }

  test(s"$namePrefix: partial matches") {
    // Keys 50..150 overlap baseRows on 50..100 only.
    val overlappingRows = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
    verifyJoin(baseRows, overlappingRows)
    verifyJoin(overlappingRows, baseRows)
  }

  test(s"$namePrefix: full matches") {
    // Same keys as baseRows, different string values.
    val sameKeyRows = baseRows.map { case (k, v) => (k, "cooper" + v) }
    verifyJoin(baseRows, sameKeyRows)
    verifyJoin(sameKeyRows, baseRows)
  }
}
|
||||
|
||||
/**
 * Computes the rows an outer join on the integer key is expected to produce.
 *
 * Each result row is `(leftKey, leftValue, rightKey, rightValue)`. When one side has no row
 * for a key, that side is filled in with key `0` and a `null` value.
 *
 * @param leftInput rows of the left side as (key, value) pairs.
 * @param rightInput rows of the right side as (key, value) pairs.
 * @param joinType one of `LeftOuter`, `RightOuter` or `FullOuter`.
 * @return the expected joined rows.
 * @throws IllegalArgumentException for any other join type.
 */
private def generateExpectedOutput(
    leftInput: Array[(Int, String)],
    rightInput: Array[(Int, String)],
    joinType: JoinType): Array[(Int, String, Int, String)] = {

  // One row per left tuple, looked up against the right side by key.
  def rowsDrivenByLeft: Array[(Int, String, Int, String)] = {
    val rightByKey = rightInput.toMap
    leftInput.map { case (key, value) =>
      rightByKey.get(key) match {
        case Some(matched) => (key, value, key, matched)
        case None => (key, value, 0, null: String)
      }
    }
  }

  // One row per right tuple, looked up against the left side by key.
  def rowsDrivenByRight: Array[(Int, String, Int, String)] = {
    val leftByKey = leftInput.toMap
    rightInput.map { case (key, value) =>
      leftByKey.get(key) match {
        case Some(matched) => (key, matched, key, value)
        case None => (0, null: String, key, value)
      }
    }
  }

  joinType match {
    case LeftOuter =>
      rowsDrivenByLeft

    case RightOuter =>
      rowsDrivenByRight

    case FullOuter =>
      // A key present on both sides yields the same tuple from each pass (assuming keys are
      // unique within an input), so `distinct` keeps a single copy of it.
      (rowsDrivenByLeft ++ rowsDrivenByRight).distinct

    case other =>
      throw new IllegalArgumentException(s"Join type $other is not applicable")
  }
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,28 +17,33 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
|
||||
import org.apache.spark.sql.types.{IntegerType, StringType}
|
||||
|
||||
class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
|
||||
|
||||
test("basic") {
|
||||
val output = testData.queryExecution.sparkPlan.output
|
||||
val columns = Seq(output(1), output(0))
|
||||
checkAnswer(
|
||||
testData,
|
||||
node => ProjectNode(conf, columns, node),
|
||||
testData.select("value", "key").collect()
|
||||
)
|
||||
class ProjectNodeSuite extends LocalNodeTest {
|
||||
private val pieAttributes = Seq(
|
||||
AttributeReference("id", IntegerType)(),
|
||||
AttributeReference("age", IntegerType)(),
|
||||
AttributeReference("name", StringType)())
|
||||
|
||||
// Projects the first and third columns (id, name) out of the given (id, age, name) rows with a
// ProjectNode and checks that the collected rows equal the same projection done on the input.
private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
  val source = new DummyNode(pieAttributes, inputData)
  val keptColumns: Seq[NamedExpression] = Seq(source.output(0), source.output(2))
  val nodeUnderTest = new ProjectNode(conf, keptColumns, source)

  val expectedRows = inputData.map { triple => (triple._1, triple._3) }
  val actualRows = nodeUnderTest.collect().map { row =>
    (row.getInt(0), row.getString(1))
  }
  assert(actualRows === expectedRows)
}
|
||||
|
||||
test("empty") {
|
||||
val output = emptyTestData.queryExecution.sparkPlan.output
|
||||
val columns = Seq(output(1), output(0))
|
||||
checkAnswer(
|
||||
emptyTestData,
|
||||
node => ProjectNode(conf, columns, node),
|
||||
emptyTestData.select("value", "key").collect()
|
||||
)
|
||||
testProject()
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -17,21 +17,32 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
|
||||
|
||||
|
||||
class SampleNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
private def testSample(withReplacement: Boolean): Unit = {
|
||||
test(s"withReplacement: $withReplacement") {
|
||||
val seed = 0L
|
||||
val input = sqlContext.sparkContext.
|
||||
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
|
||||
toDF("key", "value")
|
||||
checkAnswer(
|
||||
input,
|
||||
node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
|
||||
input.sample(withReplacement, 0.3, seed).collect()
|
||||
)
|
||||
val seed = 0L
|
||||
val lowerb = 0.0
|
||||
val upperb = 0.3
|
||||
val maybeOut = if (withReplacement) "" else "out"
|
||||
test(s"with$maybeOut replacement") {
|
||||
val inputData = (1 to 1000).map { i => (i, i) }.toArray
|
||||
val inputNode = new DummyNode(kvIntAttributes, inputData)
|
||||
val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode)
|
||||
val sampler =
|
||||
if (withReplacement) {
|
||||
new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false)
|
||||
} else {
|
||||
new BernoulliCellSampler[(Int, Int)](lowerb, upperb)
|
||||
}
|
||||
sampler.setSeed(seed)
|
||||
val expectedOutput = sampler.sample(inputData.iterator).toArray
|
||||
val actualOutput = sampleNode.collect().map { case row =>
|
||||
(row.getInt(0), row.getInt(1))
|
||||
}
|
||||
assert(actualOutput === expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,38 +17,34 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.Column
|
||||
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.SortOrder
|
||||
|
||||
|
||||
class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
|
||||
|
||||
import testImplicits._
|
||||
|
||||
// Turns each column into a SortOrder: an expression that already is a SortOrder is kept as is,
// any other expression is wrapped as an ascending sort on that expression.
private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
  sortExprs.map(_.expr).map {
    case alreadyOrdered: SortOrder => alreadyOrdered
    case plain: Expression => SortOrder(plain, Ascending)
  }
}
|
||||
|
||||
private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
|
||||
val testCaseName = if (desc) "desc" else "asc"
|
||||
test(testCaseName) {
|
||||
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
|
||||
val sortColumn = if (desc) input.col("key").desc else input.col("key")
|
||||
checkAnswer(
|
||||
input,
|
||||
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
|
||||
input.sort(sortColumn).limit(5).collect()
|
||||
)
|
||||
// Registers a test named "asc" or "desc" that feeds shuffled (i, i) pairs for 1..100 into a
// TakeOrderedAndProjectNode sorting on, and projecting, the first column, then checks that the
// collected keys are exactly the first `limit` keys in the requested order.
private def testTakeOrderedAndProject(desc: Boolean): Unit = {
  val limit = 10
  val caseName = if (desc) "desc" else "asc"
  test(caseName) {
    // Shuffle so the input does not arrive already sorted.
    val shuffledKeys = Random.shuffle((1 to 100).toList)
    val inputData = shuffledKeys.map { i => (i, i) }.toArray
    val inputNode = new DummyNode(kvIntAttributes, inputData)
    val sortKey = inputNode.output(0)
    val ordering = SortOrder(sortKey, if (desc) Descending else Ascending)
    val nodeUnderTest = new TakeOrderedAndProjectNode(
      conf, limit, Seq(ordering), Some(Seq(sortKey)), inputNode)

    // Multiplying by -1 flips the natural ordering for the descending case.
    val sign = if (desc) -1 else 1
    val expectedKeys = inputData.map(_._1).sortBy(_ * sign).take(limit)
    val actualKeys = nodeUnderTest.collect().map(_.getInt(0))
    assert(actualKeys === expectedKeys)
  }
}
|
||||
|
||||
testTakeOrderedAndProjectNode(desc = false)
|
||||
testTakeOrderedAndProjectNode(desc = true)
|
||||
testTakeOrderedAndProject(desc = false)
|
||||
testTakeOrderedAndProject(desc = true)
|
||||
}
|
||||
|
|
|
@ -17,36 +17,39 @@
|
|||
|
||||
package org.apache.spark.sql.execution.local
|
||||
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
|
||||
class UnionNodeSuite extends LocalNodeTest {
|
||||
|
||||
test("basic") {
|
||||
checkAnswer2(
|
||||
testData,
|
||||
testData,
|
||||
(node1, node2) => UnionNode(conf, Seq(node1, node2)),
|
||||
testData.unionAll(testData).collect()
|
||||
)
|
||||
// Wraps every input array in its own DummyNode, unions all of them with a UnionNode, and checks
// that the collected rows equal the inputs concatenated in the order they were given.
private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = {
  val children = inputData.map(rows => new DummyNode(kvIntAttributes, rows))
  val nodeUnderTest = new UnionNode(conf, children)

  val expectedRows = inputData.flatMap(_.toSeq)
  val actualRows = nodeUnderTest.collect().map { row =>
    (row.getInt(0), row.getInt(1))
  }
  assert(actualRows === expectedRows)
}
|
||||
|
||||
test("empty") {
|
||||
checkAnswer2(
|
||||
emptyTestData,
|
||||
emptyTestData,
|
||||
(node1, node2) => UnionNode(conf, Seq(node1, node2)),
|
||||
emptyTestData.unionAll(emptyTestData).collect()
|
||||
)
|
||||
testUnion(Seq(Array.empty))
|
||||
testUnion(Seq(Array.empty, Array.empty))
|
||||
}
|
||||
|
||||
test("complicated union") {
|
||||
val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
|
||||
emptyTestData, emptyTestData, testData, emptyTestData)
|
||||
doCheckAnswer(
|
||||
dfs,
|
||||
nodes => UnionNode(conf, nodes),
|
||||
dfs.reduce(_.unionAll(_)).collect()
|
||||
)
|
||||
test("self") {
|
||||
val data = (1 to 100).map { i => (i, i) }.toArray
|
||||
testUnion(Seq(data))
|
||||
testUnion(Seq(data, data))
|
||||
testUnion(Seq(data, data, data))
|
||||
}
|
||||
|
||||
test("basic") {
|
||||
val zero = Array.empty[(Int, Int)]
|
||||
val one = (1 to 100).map { i => (i, i) }.toArray
|
||||
val two = (50 to 150).map { i => (i, i) }.toArray
|
||||
val three = (800 to 900).map { i => (i, i) }.toArray
|
||||
testUnion(Seq(zero, one, two, three))
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue