[SPARK-16355][SPARK-16354][SQL] Fix Bugs When LIMIT/TABLESAMPLE is Non-foldable, Zero or Negative

#### What changes were proposed in this pull request?
**Issue 1:** When a query contains LIMIT/TABLESAMPLE 0, the estimated statistics can be zero. The results are still correct, but a zero estimate can cause a huge performance regression. For example,
```Scala
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
  .createOrReplaceTempView("test")
val df1 = spark.table("test")
val df2 = spark.table("test").limit(0)
val df = df1.join(df2, Seq("k"), "left")
```
The statistics of both `df` and `df2` are zero. The statistics values should never be zero; otherwise the `sizeInBytes` of a `BinaryNode`, which is the product of its children's sizes, will also be zero. This PR increases the estimate to `1` when the number of rows is 0.
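
To make the effect visible, the estimate can be read off the analyzed logical plan. This is a minimal sketch against the internal `queryExecution.analyzed.statistics` API (the same one the new `StatisticsSuite` tests below use), reusing the `df`/`df2` defined above:
```Scala
// Before this PR both values print 0, and the join estimate collapses to 0.
// After this PR the LIMIT 0 plan reports 1 byte, so the join estimate
// (a product over the children's sizes) stays positive.
println(df2.queryExecution.analyzed.statistics.sizeInBytes)
println(df.queryExecution.analyzed.statistics.sizeInBytes)
```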

**Issue 2:** When a query contains a negative LIMIT/TABLESAMPLE, we should issue an exception. Negative values can break assumptions made in multiple parts of the implementation, for example in statistics calculation. Example queries:
```SQL
SELECT * FROM testData TABLESAMPLE (-1 rows)
SELECT * FROM testData LIMIT -1
```
This PR issues an appropriate exception in this case.
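
After the fix, both statements fail analysis up front. A sketch of the expected behavior, mirroring the new `SQLQuerySuite` tests below (it assumes a SparkSession named `spark` and a registered `testData` temp view):
```Scala
import org.apache.spark.sql.AnalysisException

try {
  spark.sql("SELECT * FROM testData LIMIT -1")
} catch {
  case e: AnalysisException =>
    // Expected message: "The limit expression must be equal to or
    // greater than 0, but got -1"
    println(e.getMessage)
}
```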

**Issue 3:** Spark SQL follows Hive's restriction on the LIMIT clause: the argument must evaluate to a constant value. It can be a numeric literal or another kind of numeric expression involving operators, casts, and function return values, but it cannot refer to a column or use a subquery. Currently, we do not detect whether the expression in the LIMIT clause is foldable. If it is not, we might issue a strange error message. For example,
```SQL
SELECT * FROM testData LIMIT rand() > 0.2
```
Then a misleading error message like the following is issued:
```
java.lang.AssertionError: assertion failed: No plan for GlobalLimit (_nondeterministic#203 > 0.2)
+- Project [key#11, value#12, rand(-1441968339187861415) AS _nondeterministic#203]
   +- LocalLimit (_nondeterministic#202 > 0.2)
      +- Project [key#11, value#12, rand(-1308350387169017676) AS _nondeterministic#202]
         +- LogicalRDD [key#11, value#12]
```
This PR detects it and then issues a meaningful error message.
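
The check is driven by `Expression.foldable`: expressions that constant-fold are accepted, while anything referencing a column or producing a non-integer value is rejected. A rough illustration mirroring the test cases added below (each statement run on its own, against the same `testData` view):
```Scala
spark.sql("SELECT * FROM testData LIMIT 9 + 1")           // OK: folds to the constant 10
spark.sql("SELECT * FROM testData LIMIT CAST(1 AS INT)")  // OK: foldable cast
spark.sql("SELECT * FROM testData LIMIT key > 3")         // AnalysisException: non-foldable
spark.sql("SELECT * FROM testData LIMIT true")            // AnalysisException: not integer type
```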

#### How was this patch tested?
Added test cases.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #14034 from gatorsmile/limit.
Authored by gatorsmile on 2016-07-11 16:21:13 +08:00; committed by Wenchen Fan
parent 82f0874453
commit e226278941
5 changed files with 118 additions and 4 deletions

```diff
@@ -46,6 +46,21 @@ trait CheckAnalysis extends PredicateHelper {
     }).length > 1
   }
 
+  private def checkLimitClause(limitExpr: Expression): Unit = {
+    limitExpr match {
+      case e if !e.foldable => failAnalysis(
+        "The limit expression must evaluate to a constant value, but got " +
+          limitExpr.sql)
+      case e if e.dataType != IntegerType => failAnalysis(
+        s"The limit expression must be integer type, but got " +
+          e.dataType.simpleString)
+      case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis(
+        "The limit expression must be equal to or greater than 0, but got " +
+          e.eval().asInstanceOf[Int])
+      case e => // OK
+    }
+  }
+
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
@@ -251,6 +266,10 @@ trait CheckAnalysis extends PredicateHelper {
              s"but one table has '${firstError.output.length}' columns and another table has " +
              s"'${s.children.head.output.length}' columns")
 
+          case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)
+
+          case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)
+
          case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
            p match {
              case _: Filter | _: Aggregate | _: Project => // Ok
```

```diff
@@ -660,7 +660,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   }
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    val sizeInBytes = if (limit == 0) {
+      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+      // (product of children).
+      1
+    } else {
+      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    }
     child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
@@ -675,7 +681,13 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   }
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    val sizeInBytes = if (limit == 0) {
+      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
+      // (product of children).
+      1
+    } else {
+      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    }
     child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
```

```diff
@@ -352,6 +352,12 @@ class AnalysisErrorSuite extends AnalysisTest {
     "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil
   )
 
+  errorTest(
+    "num_rows in limit clause must be equal to or greater than 0",
+    listRelation.limit(-1),
+    "The limit expression must be equal to or greater than 0, but got -1" :: Nil
+  )
+
   errorTest(
     "more than one generators in SELECT",
     listRelation.select(Explode('list), Explode('list)),
```

```diff
@@ -660,11 +660,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("limit") {
     checkAnswer(
-      sql("SELECT * FROM testData LIMIT 10"),
+      sql("SELECT * FROM testData LIMIT 9 + 1"),
       testData.take(10).toSeq)
 
     checkAnswer(
-      sql("SELECT * FROM arrayData LIMIT 1"),
+      sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"),
       arrayData.collect().take(1).map(Row.fromTuple).toSeq)
 
     checkAnswer(
@@ -672,6 +672,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       mapData.collect().take(1).map(Row.fromTuple).toSeq)
   }
 
+  test("non-foldable expressions in LIMIT") {
+    val e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT key > 3")
+    }.getMessage
+    assert(e.contains("The limit expression must evaluate to a constant value, " +
+      "but got (testdata.`key` > 3)"))
+  }
+
+  test("Expressions in limit clause are not integer") {
+    var e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT true")
+    }.getMessage
+    assert(e.contains("The limit expression must be integer type, but got boolean"))
+
+    e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT 'a'")
+    }.getMessage
+    assert(e.contains("The limit expression must be integer type, but got string"))
+  }
+
+  test("negative in LIMIT or TABLESAMPLE") {
+    val expected = "The limit expression must be equal to or greater than 0, but got -1"
+    var e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData TABLESAMPLE (-1 rows)")
+    }.getMessage
+    assert(e.contains(expected))
+
+    e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT -1")
+    }.getMessage
+    assert(e.contains(expected))
+  }
+
   test("CTE feature") {
     checkAnswer(
       sql("with q1 as (select * from testData limit 10) select * from q1"),
```

```diff
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 class StatisticsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
     val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
@@ -31,4 +33,46 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
 
+  test("estimates the size of limit") {
+    withTempTable("test") {
+      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+        .createOrReplaceTempView("test")
+      Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
+        val df = sql(s"""SELECT * FROM test limit $limit""")
+
+        val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
+          g.statistics.sizeInBytes
+        }
+        assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+        assert(sizesGlobalLimit.head === BigInt(expected),
+          s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
+
+        val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
+          l.statistics.sizeInBytes
+        }
+        assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+        assert(sizesLocalLimit.head === BigInt(expected),
+          s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
+      }
+    }
+  }
+
+  test("estimates the size of a limit 0 on outer join") {
+    withTempTable("test") {
+      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+        .createOrReplaceTempView("test")
+      val df1 = spark.table("test")
+      val df2 = spark.table("test").limit(0)
+      val df = df1.join(df2, Seq("k"), "left")
+
+      val sizes = df.queryExecution.analyzed.collect { case g: Join =>
+        g.statistics.sizeInBytes
+      }
+
+      assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
+      assert(sizes.head === BigInt(96),
+        s"expected exact size 96 for table 'test', got: ${sizes.head}")
+    }
+  }
 }
```