[SPARK-13263][SQL] SQL Generation Support for Tablesample

In the parser, tableSample clause is part of tableSource.
```
tableSource
init { gParent.pushMsg("table source", state); }
after { gParent.popMsg(state); }
    : tabname=tableName
    ((tableProperties) => props=tableProperties)?
    ((tableSample) => ts=tableSample)?
    ((KW_AS) => (KW_AS alias=Identifier)
    |
    (Identifier) => (alias=Identifier))?
    -> ^(TOK_TABREF $tabname $props? $ts? $alias?)
    ;
```

Two typical query samples using TABLESAMPLE are:
```
    "SELECT s.id FROM t0 TABLESAMPLE(10 PERCENT) s"
    "SELECT * FROM t0 TABLESAMPLE(0.1 PERCENT)"
```

FYI, the logical plan of a TABLESAMPLE query:
```
sql("SELECT * FROM t0 TABLESAMPLE(0.1 PERCENT)").explain(true)

== Analyzed Logical Plan ==
id: bigint
Project [id#16L]
+- Sample 0.0, 0.001, false, 381
   +- Subquery t0
      +- Relation[id#16L] ParquetRelation
```

Thanks! cc liancheng

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

This patch had conflicts when merged, resolved by
Committer: Cheng Lian <lian@databricks.com>

Closes #11148 from gatorsmile/tablesplitsample.
This commit is contained in:
gatorsmile 2016-02-23 16:13:09 +08:00 committed by Cheng Lian
parent 5cd3e6f60b
commit 87250580f2
8 changed files with 103 additions and 36 deletions

View file

@ -499,12 +499,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
s"Sampling fraction ($fraction) must be on interval [0, 100]")
Sample(0.0, fraction.toDouble / 100, withReplacement = false,
(math.random * 1000).toInt,
relation)
relation)(
isTableSample = true)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
isTableSample = true)
case a =>
noParseRule("Sampling", a)
}.getOrElse(relation)

View file

@ -123,7 +123,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
// Push down projection into sample
case Project(projectList, s @ Sample(lb, up, replace, seed, child)) =>
Sample(lb, up, replace, seed,
Project(projectList, child))
Project(projectList, child))()
}
}

View file

@ -637,15 +637,18 @@ case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the LogicalPlan
* @param isTableSample Is created from TABLESAMPLE in the parser.
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan) extends UnaryNode {
child: LogicalPlan)(
val isTableSample: java.lang.Boolean = false) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
}
/**

View file

@ -640,7 +640,7 @@ class FilterPushdownSuite extends PlanTest {
test("push project and filter down into sample") {
val x = testRelation.subquery('x)
val originalQuery =
Sample(0.0, 0.6, false, 11L, x).select('a)
Sample(0.0, 0.6, false, 11L, x)().select('a)
val originalQueryAnalyzed =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery))
@ -648,7 +648,7 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQueryAnalyzed)
val correctAnswer =
Sample(0.0, 0.6, false, 11L, x.select('a))
Sample(0.0, 0.6, false, 11L, x.select('a))()
comparePlans(optimized, correctAnswer.analyze)
}

View file

@ -1041,7 +1041,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
}
/**
@ -1073,7 +1073,7 @@ class DataFrame private[sql](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
}.toArray
}

View file

@ -564,7 +564,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
withPlan(Sample(0.0, fraction, withReplacement, seed, _))
withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
/**
* Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.

View file

@ -91,6 +91,23 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case Limit(limitExpr, child) =>
s"${toSQL(child)} LIMIT ${limitExpr.sql}"
case p: Sample if p.isTableSample =>
val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
p.child match {
case m: MetastoreRelation =>
val aliasName = m.alias.getOrElse("")
build(
s"`${m.databaseName}`.`${m.tableName}`",
"TABLESAMPLE(" + fraction + " PERCENT)",
aliasName)
case s: SubqueryAlias =>
val aliasName = if (s.child.isInstanceOf[SubqueryAlias]) s.alias else ""
val plan = if (s.child.isInstanceOf[SubqueryAlias]) s.child else s
build(toSQL(plan), "TABLESAMPLE(" + fraction + " PERCENT)", aliasName)
case _ =>
build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
}
case p: Filter =>
val whereOrHaving = p.child match {
case _: Aggregate => "HAVING"
@ -232,6 +249,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
| OneRowRelation
| _: LocalLimit
| _: GlobalLimit
| _: Sample
) => plan
case plan: Project =>

View file

@ -26,25 +26,32 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
import testImplicits._
protected override def beforeAll(): Unit = {
sql("DROP TABLE IF EXISTS parquet_t0")
sql("DROP TABLE IF EXISTS parquet_t1")
sql("DROP TABLE IF EXISTS parquet_t2")
sql("DROP TABLE IF EXISTS t0")
sql("DROP TABLE IF EXISTS t1")
sql("DROP TABLE IF EXISTS t2")
sqlContext.range(10).write.saveAsTable("t0")
sqlContext.range(10).write.saveAsTable("parquet_t0")
sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0")
sqlContext
.range(10)
.select('id as 'key, concat(lit("val_"), 'id) as 'value)
.write
.saveAsTable("t1")
.saveAsTable("parquet_t1")
sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
sqlContext
.range(10)
.select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd)
.write
.saveAsTable("parquet_t2")
}
override protected def afterAll(): Unit = {
sql("DROP TABLE IF EXISTS parquet_t0")
sql("DROP TABLE IF EXISTS parquet_t1")
sql("DROP TABLE IF EXISTS parquet_t2")
sql("DROP TABLE IF EXISTS t0")
sql("DROP TABLE IF EXISTS t1")
sql("DROP TABLE IF EXISTS t2")
}
private def checkHiveQl(hiveQl: String): Unit = {
@ -83,7 +90,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("in") {
checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
checkHiveQl("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)")
}
test("not in") {
@ -95,11 +102,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("aggregate function in having clause") {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0")
}
test("aggregate function in order by clause") {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)")
checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)")
}
// When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into
@ -107,11 +114,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
// execution since these aliases have different expression ID. But this introduces name collision
// when converting resolved plans back to SQL query strings as expression IDs are stripped.
test("aggregate function in order by clause with multiple order keys") {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)")
checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)")
}
test("type widening in union") {
checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0")
checkHiveQl("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0")
}
test("union distinct") {
@ -124,9 +131,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
// UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`))
// UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1
test("three-child union") {
checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0")
checkHiveQl(
"""
|SELECT id FROM parquet_t0
|UNION ALL SELECT id FROM parquet_t0
|UNION ALL SELECT id FROM parquet_t0
""".stripMargin)
}
test("intersect") {
checkHiveQl("SELECT * FROM t0 INTERSECT SELECT * FROM t0")
}
@ -136,59 +149,90 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
}
test("self join") {
checkHiveQl("SELECT x.key FROM t1 x JOIN t1 y ON x.key = y.key")
checkHiveQl("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key")
}
test("self join with group by") {
checkHiveQl("SELECT x.key, COUNT(*) FROM t1 x JOIN t1 y ON x.key = y.key group by x.key")
checkHiveQl(
"SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key")
}
test("case") {
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0")
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0")
}
test("case with else") {
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0")
checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0")
}
test("case with key") {
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0")
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0")
}
test("case with key and else") {
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0")
checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0")
}
test("select distinct without aggregate functions") {
checkHiveQl("SELECT DISTINCT id FROM t0")
checkHiveQl("SELECT DISTINCT id FROM parquet_t0")
}
test("cluster by") {
checkHiveQl("SELECT id FROM t0 CLUSTER BY id")
checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id")
}
test("distribute by") {
checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id")
checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id")
}
test("distribute by with sort by") {
checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id")
checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id")
}
test("distinct aggregation") {
checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0")
checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0")
}
test("TABLESAMPLE") {
// Project [id#2L]
// +- Sample 0.0, 1.0, false, ...
// +- Subquery s
// +- Subquery parquet_t0
// +- Relation[id#2L] ParquetRelation
checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s")
// Project [id#2L]
// +- Sample 0.0, 1.0, false, ...
// +- Subquery parquet_t0
// +- Relation[id#2L] ParquetRelation
checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)")
// Project [id#21L]
// +- Sample 0.0, 1.0, false, ...
// +- MetastoreRelation default, t0, Some(s)
checkHiveQl("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s")
// Project [id#24L]
// +- Sample 0.0, 1.0, false, ...
// +- MetastoreRelation default, t0, None
checkHiveQl("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)")
// When a sampling fraction is not 100%, the returned results are random.
// Thus, added an always-false filter here to check if the generated plan can be successfully
// executed.
checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0")
checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0")
}
// TODO Enable this
// Query plans transformed by DistinctAggregationRewriter are not recognized yet
ignore("multi-distinct columns") {
checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a")
checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a")
}
test("persisted data source relations") {
Seq("orc", "json", "parquet").foreach { format =>
val tableName = s"${format}_t0"
val tableName = s"${format}_parquet_t0"
withTable(tableName) {
sqlContext.range(10).write.format(format).saveAsTable(tableName)
checkHiveQl(s"SELECT id FROM $tableName")