[SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8645 from cloud-fan/test.
This commit is contained in:
Wenchen Fan 2015-12-15 18:29:19 -08:00 committed by Andrew Or
parent c5b6b398d5
commit a89e8b6122
2 changed files with 27 additions and 21 deletions

View file

@ -275,13 +275,14 @@ package object dsl {
def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
// TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
alias: Option[String] = None,
outputNames: Seq[String] = Nil): LogicalPlan =
Generate(generator, join = join, outer = outer, alias,
outputNames.map(UnresolvedAttribute(_)), logicalPlan)
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest {
test("Column pruning for Generate when Generate.join = false") {
val input = LocalRelation('a.int, 'b.array(StringType))
val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
val query = input.generate(Explode('b), join = false).analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Generate(Explode('b), false, false, None, 's.string :: Nil,
Project('b.attr :: Nil, input)).analyze
val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
comparePlans(optimized, correctAnswer)
}
@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest {
val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
val query =
Project(Seq('a, 's),
Generate(Explode('c), true, false, None, 's.string :: Nil,
input)).analyze
input
.generate(Explode('c), join = true, outputNames = "explode" :: Nil)
.select('a, 'explode)
.analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Project(Seq('a, 's),
Generate(Explode('c), true, false, None, 's.string :: Nil,
Project(Seq('a, 'c),
input))).analyze
input
.select('a, 'c)
.generate(Explode('c), join = true, outputNames = "explode" :: Nil)
.select('a, 'explode)
.analyze
comparePlans(optimized, correctAnswer)
}
@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest {
val input = LocalRelation('b.array(StringType))
val query =
Project(('s + 1).as("s+1") :: Nil,
Generate(Explode('b), true, false, None, 's.string :: Nil,
input)).analyze
input
.generate(Explode('b), join = true, outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Project(('s + 1).as("s+1") :: Nil,
Generate(Explode('b), false, false, None, 's.string :: Nil,
input)).analyze
input
.generate(Explode('b), join = false, outputNames = "explode" :: Nil)
.select(('explode + 1).as("result"))
.analyze
comparePlans(optimized, correctAnswer)
}