diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 069c665a39..a0a56d728c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -499,12 +499,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             s"Sampling fraction ($fraction) must be on interval [0, 100]")
           Sample(0.0, fraction.toDouble / 100, withReplacement = false,
             (math.random * 1000).toInt,
-            relation)
+            relation)(
+            isTableSample = true)
         case Token("TOK_TABLEBUCKETSAMPLE",
                Token(numerator, Nil) ::
                Token(denominator, Nil) :: Nil) =>
           val fraction = numerator.toDouble / denominator.toDouble
-          Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+          Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
+            isTableSample = true)
         case a =>
           noParseRule("Sampling", a)
       }.getOrElse(relation)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1554382840..1f05f2065c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -123,7 +123,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
     // Push down projection into sample
     case Project(projectList, s @ Sample(lb, up, replace, seed, child)) =>
       Sample(lb, up, replace, seed,
-        Project(projectList, child))
+        Project(projectList, child))()
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 70ecbce829..c98d33d5a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -637,15 +637,18 @@ case class SubqueryAlias(alias: String, child: LogicalPlan) extends UnaryNode {
  * @param withReplacement Whether to sample with replacement.
  * @param seed the random seed
  * @param child the LogicalPlan
+ * @param isTableSample Whether this Sample was created from a TABLESAMPLE clause in the parser.
  */
 case class Sample(
     lowerBound: Double,
     upperBound: Double,
     withReplacement: Boolean,
     seed: Long,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan)(
+    val isTableSample: java.lang.Boolean = false) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
+
+  override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
 }
 
 /**
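Note: the hunks above lean on a Scala detail worth spelling out. Because isTableSample lives in a second parameter list, the compiler-generated unapply (and equals/hashCode) of the case class still cover only the first list, so existing pattern matches such as the one in SamplePushDown compile unchanged. A minimal standalone sketch of that behavior — Node and fromSql are made-up names, not Spark classes:

// Hypothetical stand-in for Sample: the flag sits in a second parameter list.
case class Node(fraction: Double, child: String)(val fromSql: Boolean = false)

object CurriedMatchDemo extends App {
  val n = Node(0.5, "relation")(fromSql = true)
  n match {
    // unapply is generated from the first parameter list only,
    // so the extra flag never has to appear in the pattern.
    case Node(f, c) => println(s"fraction=$f, child=$c, fromSql=${n.fromSql}")
  }
  // equals/hashCode also ignore the second list:
  println(Node(0.5, "relation")() == n) // prints: true
}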
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 7805723ec8..70b34cbb24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -640,7 +640,7 @@ class FilterPushdownSuite extends PlanTest {
   test("push project and filter down into sample") {
     val x = testRelation.subquery('x)
     val originalQuery =
-      Sample(0.0, 0.6, false, 11L, x).select('a)
+      Sample(0.0, 0.6, false, 11L, x)().select('a)
 
     val originalQueryAnalyzed =
       EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery))
@@ -648,7 +648,7 @@ class FilterPushdownSuite extends PlanTest {
     val optimized = Optimize.execute(originalQueryAnalyzed)
 
     val correctAnswer =
-      Sample(0.0, 0.6, false, 11L, x.select('a))
+      Sample(0.0, 0.6, false, 11L, x.select('a))()
 
     comparePlans(optimized, correctAnswer.analyze)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index e3412f7a2e..f590ac0114 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1041,7 +1041,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
-    Sample(0.0, fraction, withReplacement, seed, logicalPlan)
+    Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
   }
 
   /**
@@ -1073,7 +1073,7 @@ class DataFrame private[sql](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
+      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
     }.toArray
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ea7e7255ab..dd1fbcf3c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -564,7 +564,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
-    withPlan(Sample(0.0, fraction, withReplacement, seed, _))
+    withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
 
   /**
    * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
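Note: two details in the basicOperators hunk are easy to miss. Catalyst rebuilds tree nodes reflectively (TreeNode.makeCopy) during plan transformations, and that copy path only knows about the first parameter list; otherCopyArgs is the hook that re-supplies the remaining constructor arguments. It is also why the flag is typed java.lang.Boolean: the hook returns Seq[AnyRef], and scala.Boolean is an AnyVal. A rough standalone sketch of the contract — Sampled and this makeCopy are illustrative, not Catalyst's real implementation:

// Illustrative model of the otherCopyArgs contract; not Catalyst's TreeNode.
class Sampled(val fraction: Double, val child: String)(val isTableSample: java.lang.Boolean) {
  // Extra constructor args that a reflective copy must re-supply.
  protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil

  def makeCopy(newArgs: Array[AnyRef]): Sampled = {
    // Scala flattens both parameter lists into one JVM constructor.
    val ctor = classOf[Sampled].getConstructors.head
    ctor.newInstance(newArgs ++ otherCopyArgs: _*).asInstanceOf[Sampled]
  }
}

object CopyDemo extends App {
  val s = new Sampled(0.5, "t0")(true)
  val copied = s.makeCopy(Array(Double.box(1.0), "t1"))
  println(copied.isTableSample) // prints: true — the flag survives the copy
}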
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e66cc127ea..5182af9d20 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -91,6 +91,23 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     case Limit(limitExpr, child) =>
       s"${toSQL(child)} LIMIT ${limitExpr.sql}"
 
+    case p: Sample if p.isTableSample =>
+      val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100))
+      p.child match {
+        case m: MetastoreRelation =>
+          val aliasName = m.alias.getOrElse("")
+          build(
+            s"`${m.databaseName}`.`${m.tableName}`",
+            "TABLESAMPLE(" + fraction + " PERCENT)",
+            aliasName)
+        case s: SubqueryAlias =>
+          val aliasName = if (s.child.isInstanceOf[SubqueryAlias]) s.alias else ""
+          val plan = if (s.child.isInstanceOf[SubqueryAlias]) s.child else s
+          build(toSQL(plan), "TABLESAMPLE(" + fraction + " PERCENT)", aliasName)
+        case _ =>
+          build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)")
+      }
+
     case p: Filter =>
       val whereOrHaving = p.child match {
         case _: Aggregate => "HAVING"
@@ -232,6 +249,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
            | OneRowRelation
            | _: LocalLimit
            | _: GlobalLimit
+           | _: Sample
         ) => plan
 
     case plan: Project =>
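Note: the fraction arithmetic above maps a Sample's [lowerBound, upperBound) range back onto Hive's percentage syntax, clamped to [0, 100]. A self-contained sketch of just that step — tableSampleClause is a made-up helper for illustration, not part of SQLBuilder:

object TableSampleDemo extends App {
  // Mirrors the clamp in the SQLBuilder case: bounds are fractions of the input.
  def tableSampleClause(lowerBound: Double, upperBound: Double): String = {
    val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100))
    s"TABLESAMPLE($fraction PERCENT)"
  }

  println(tableSampleClause(0.0, 1.0))   // TABLESAMPLE(100.0 PERCENT)
  println(tableSampleClause(0.0, 0.001)) // TABLESAMPLE(0.1 PERCENT)
}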
0") } test("aggregate function in order by clause") { - checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)") } // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into @@ -107,11 +114,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // execution since these aliases have different expression ID. But this introduces name collision // when converting resolved plans back to SQL query strings as expression IDs are stripped. test("aggregate function in order by clause with multiple order keys") { - checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)") } test("type widening in union") { - checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + checkHiveQl("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0") } test("union distinct") { @@ -124,9 +131,15 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1 test("three-child union") { - checkHiveQl("SELECT id FROM t0 UNION ALL SELECT id FROM t0 UNION ALL SELECT id FROM t0") + checkHiveQl( + """ + |SELECT id FROM parquet_t0 + |UNION ALL SELECT id FROM parquet_t0 + |UNION ALL SELECT id FROM parquet_t0 + """.stripMargin) } + test("intersect") { checkHiveQl("SELECT * FROM t0 INTERSECT SELECT * FROM t0") } @@ -136,59 +149,90 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("self join") { - checkHiveQl("SELECT x.key FROM t1 x JOIN t1 y ON x.key = y.key") + checkHiveQl("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key") } test("self join with group by") { - checkHiveQl("SELECT x.key, COUNT(*) FROM t1 x JOIN t1 y ON x.key = y.key group by x.key") + checkHiveQl( + "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key") } - test("case") { - checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0") } test("case with else") { - checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0") } test("case with key") { - checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0") } test("case with key and else") { - checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0") } test("select distinct without aggregate functions") { - checkHiveQl("SELECT DISTINCT id FROM t0") + checkHiveQl("SELECT DISTINCT id FROM parquet_t0") } test("cluster by") { - checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id") } test("distribute by") { - checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id") } test("distribute by with sort by") { - checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id") } test("distinct 
aggregation") { - checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0") + } + + test("TABLESAMPLE") { + // Project [id#2L] + // +- Sample 0.0, 1.0, false, ... + // +- Subquery s + // +- Subquery parquet_t0 + // +- Relation[id#2L] ParquetRelation + checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s") + + // Project [id#2L] + // +- Sample 0.0, 1.0, false, ... + // +- Subquery parquet_t0 + // +- Relation[id#2L] ParquetRelation + checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)") + + // Project [id#21L] + // +- Sample 0.0, 1.0, false, ... + // +- MetastoreRelation default, t0, Some(s) + checkHiveQl("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s") + + // Project [id#24L] + // +- Sample 0.0, 1.0, false, ... + // +- MetastoreRelation default, t0, None + checkHiveQl("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)") + + // When a sampling fraction is not 100%, the returned results are random. + // Thus, added an always-false filter here to check if the generated plan can be successfully + // executed. + checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0") + checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0") } // TODO Enable this // Query plans transformed by DistinctAggregationRewriter are not recognized yet ignore("multi-distinct columns") { - checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a") } test("persisted data source relations") { Seq("orc", "json", "parquet").foreach { format => - val tableName = s"${format}_t0" + val tableName = s"${format}_parquet_t0" withTable(tableName) { sqlContext.range(10).write.format(format).saveAsTable(tableName) checkHiveQl(s"SELECT id FROM $tableName")