[SPARK-34086][SQL] RaiseError generates too much code and may fails codegen in length check for char varchar

### What changes were proposed in this pull request?

https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/133928/testReport/org.apache.spark.sql.execution/LogicalPlanTagInSparkPlanSuite/q41/

We can reduce more than 8000 bytes by removing the unnecessary CONCAT expression.

W/ this fix, for q41 in TPCDS with [Using TPCDS original definitions for char/varchar columns](https://github.com/apache/spark/pull/31012) applied, we can reduce the stage code-gen size from 22523 to 14369
```
14369  - 22523 = - 8154
```

### Why are the changes needed?

fix the perf regression(we need other improvements for q41 works), there will be a huge performance regression if codegen fails

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

modified uts

Closes #31150 from yaooqinn/SPARK-34086.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Kent Yao 2021-01-13 09:52:36 +00:00 committed by Wenchen Fan
parent 861f8bb5fb
commit 04f031acb3
3 changed files with 29 additions and 40 deletions

View file

@ -66,11 +66,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
""",
since = "3.1.0",
group = "misc_funcs")
case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
case class RaiseError(child: Expression, dataType: DataType)
extends UnaryExpression with ImplicitCastInputTypes {
def this(child: Expression) = this(child, NullType)
override def foldable: Boolean = false
override def nullable: Boolean = true
override def dataType: DataType = NullType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
override def prettyName: String = "raise_error"
@ -100,6 +102,10 @@ case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCa
}
}
object RaiseError {
def apply(child: Expression): RaiseError = new RaiseError(child)
}
/**
* A function that throws an exception if 'condition' is not true.
*/

View file

@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
object CharVarcharUtils extends Logging {
@ -202,12 +203,9 @@ object CharVarcharUtils extends Logging {
}.getOrElse(expr)
}
private def raiseError(expr: Expression, typeName: String, length: Int): Expression = {
val errorMsg = Concat(Seq(
Literal("input string of length "),
Cast(Length(expr), StringType),
Literal(s" exceeds $typeName type length limitation: $length")))
Cast(RaiseError(errorMsg), StringType)
private def raiseError(typeName: String, length: Int): Expression = {
val errMsg = UTF8String.fromString(s"Exceeds $typeName type length limitation: $length")
RaiseError(Literal(errMsg, StringType), StringType)
}
private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match {
@ -217,7 +215,7 @@ object CharVarcharUtils extends Logging {
// spaces, as we will pad char type columns/fields at read time.
If(
GreaterThan(Length(trimmed), Literal(length)),
raiseError(expr, "char", length),
raiseError("char", length),
trimmed)
case VarcharType(length) =>
@ -230,7 +228,7 @@ object CharVarcharUtils extends Logging {
expr,
If(
GreaterThan(Length(trimmed), Literal(length)),
raiseError(expr, "varchar", length),
raiseError("varchar", length),
StringRPad(trimmed, Literal(length))))
case StructType(fields) =>

View file

@ -189,8 +189,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (null)")
checkAnswer(spark.table("t"), Row(null))
val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -202,8 +201,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (1, null)")
checkAnswer(spark.table("t"), Row(1, null))
val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
}
@ -214,8 +212,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t SELECT struct(null)")
checkAnswer(spark.table("t"), Row(Row(null)))
val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -225,8 +222,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (array(null))")
checkAnswer(spark.table("t"), Row(Seq(null)))
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -234,8 +230,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
testTableWrite { typeName =>
sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format")
val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -245,8 +240,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (map('a', null))")
checkAnswer(spark.table("t"), Row(Map("a" -> null)))
val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -254,11 +248,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
testTableWrite { typeName =>
sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format")
val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
assert(e1.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e1.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
assert(e2.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e2.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -268,8 +260,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t SELECT struct(array(null))")
checkAnswer(spark.table("t"), Row(Row(Seq(null))))
val e = intercept[SparkException](sql("INSERT INTO t SELECT struct(array('123456'))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -279,8 +270,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (array(struct(null)))")
checkAnswer(spark.table("t"), Row(Seq(Row(null))))
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -290,8 +280,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (array(array(null)))")
checkAnswer(spark.table("t"), Row(Seq(Seq(null))))
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))"))
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typeName type length limitation: 5"))
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
}
}
@ -312,11 +301,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
sql("INSERT INTO t VALUES (1234, 1234)")
checkAnswer(spark.table("t"), Row("1234 ", "1234"))
val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)"))
assert(e1.getCause.getMessage.contains(
"input string of length 6 exceeds char type length limitation: 5"))
assert(e1.getCause.getMessage.contains("Exceeds char type length limitation: 5"))
val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)"))
assert(e2.getCause.getMessage.contains(
"input string of length 6 exceeds varchar type length limitation: 5"))
assert(e2.getCause.getMessage.contains("Exceeds varchar type length limitation: 5"))
}
}
@ -626,8 +613,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
sql("SELECT '123456' as col").write.format(format).save(dir.toString)
sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'")
val e = intercept[SparkException] { sql("select * from t").collect() }
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typ type length limitation: 2"))
assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
}
}
}
@ -654,8 +640,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
sql(s"CREATE TABLE t (col $typ(2)) using $format")
sql(s"ALTER TABLE t SET LOCATION '$dir'")
val e = intercept[SparkException] { spark.table("t").collect() }
assert(e.getCause.getMessage.contains(
s"input string of length 6 exceeds $typ type length limitation: 2"))
assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
}
}
}