[SPARK-34086][SQL] RaiseError generates too much code and may fails codegen in length check for char varchar
### What changes were proposed in this pull request? https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/133928/testReport/org.apache.spark.sql.execution/LogicalPlanTagInSparkPlanSuite/q41/ We can reduce more than 8000 bytes by removing the unnecessary CONCAT expression. W/ this fix, for q41 in TPCDS with [Using TPCDS original definitions for char/varchar columns](https://github.com/apache/spark/pull/31012) applied, we can reduce the stage code-gen size from 22523 to 14369 ``` 14369 - 22523 = - 8154 ``` ### Why are the changes needed? fix the perf regression(we need other improvements for q41 works), there will be a huge performance regression if codegen fails ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? modified uts Closes #31150 from yaooqinn/SPARK-34086. Authored-by: Kent Yao <yao@apache.org> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
861f8bb5fb
commit
04f031acb3
|
@ -66,11 +66,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
|
|||
""",
|
||||
since = "3.1.0",
|
||||
group = "misc_funcs")
|
||||
case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
|
||||
case class RaiseError(child: Expression, dataType: DataType)
|
||||
extends UnaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
def this(child: Expression) = this(child, NullType)
|
||||
|
||||
override def foldable: Boolean = false
|
||||
override def nullable: Boolean = true
|
||||
override def dataType: DataType = NullType
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
|
||||
|
||||
override def prettyName: String = "raise_error"
|
||||
|
@ -100,6 +102,10 @@ case class RaiseError(child: Expression) extends UnaryExpression with ImplicitCa
|
|||
}
|
||||
}
|
||||
|
||||
object RaiseError {
|
||||
def apply(child: Expression): RaiseError = new RaiseError(child)
|
||||
}
|
||||
|
||||
/**
|
||||
* A function that throws an exception if 'condition' is not true.
|
||||
*/
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
|
|||
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
object CharVarcharUtils extends Logging {
|
||||
|
||||
|
@ -202,12 +203,9 @@ object CharVarcharUtils extends Logging {
|
|||
}.getOrElse(expr)
|
||||
}
|
||||
|
||||
private def raiseError(expr: Expression, typeName: String, length: Int): Expression = {
|
||||
val errorMsg = Concat(Seq(
|
||||
Literal("input string of length "),
|
||||
Cast(Length(expr), StringType),
|
||||
Literal(s" exceeds $typeName type length limitation: $length")))
|
||||
Cast(RaiseError(errorMsg), StringType)
|
||||
private def raiseError(typeName: String, length: Int): Expression = {
|
||||
val errMsg = UTF8String.fromString(s"Exceeds $typeName type length limitation: $length")
|
||||
RaiseError(Literal(errMsg, StringType), StringType)
|
||||
}
|
||||
|
||||
private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match {
|
||||
|
@ -217,7 +215,7 @@ object CharVarcharUtils extends Logging {
|
|||
// spaces, as we will pad char type columns/fields at read time.
|
||||
If(
|
||||
GreaterThan(Length(trimmed), Literal(length)),
|
||||
raiseError(expr, "char", length),
|
||||
raiseError("char", length),
|
||||
trimmed)
|
||||
|
||||
case VarcharType(length) =>
|
||||
|
@ -230,7 +228,7 @@ object CharVarcharUtils extends Logging {
|
|||
expr,
|
||||
If(
|
||||
GreaterThan(Length(trimmed), Literal(length)),
|
||||
raiseError(expr, "varchar", length),
|
||||
raiseError("varchar", length),
|
||||
StringRPad(trimmed, Literal(length))))
|
||||
|
||||
case StructType(fields) =>
|
||||
|
|
|
@ -189,8 +189,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (null)")
|
||||
checkAnswer(spark.table("t"), Row(null))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -202,8 +201,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (1, null)")
|
||||
checkAnswer(spark.table("t"), Row(1, null))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -214,8 +212,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t SELECT struct(null)")
|
||||
checkAnswer(spark.table("t"), Row(Row(null)))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,8 +222,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (array(null))")
|
||||
checkAnswer(spark.table("t"), Row(Seq(null)))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -234,8 +230,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
testTableWrite { typeName =>
|
||||
sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format")
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -245,8 +240,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (map('a', null))")
|
||||
checkAnswer(spark.table("t"), Row(Map("a" -> null)))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,11 +248,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
testTableWrite { typeName =>
|
||||
sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format")
|
||||
val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))"))
|
||||
assert(e1.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e1.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))"))
|
||||
assert(e2.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e2.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -268,8 +260,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t SELECT struct(array(null))")
|
||||
checkAnswer(spark.table("t"), Row(Row(Seq(null))))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t SELECT struct(array('123456'))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -279,8 +270,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (array(struct(null)))")
|
||||
checkAnswer(spark.table("t"), Row(Seq(Row(null))))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -290,8 +280,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (array(array(null)))")
|
||||
checkAnswer(spark.table("t"), Row(Seq(Seq(null))))
|
||||
val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))"))
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typeName type length limitation: 5"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typeName type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -312,11 +301,9 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
|
|||
sql("INSERT INTO t VALUES (1234, 1234)")
|
||||
checkAnswer(spark.table("t"), Row("1234 ", "1234"))
|
||||
val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)"))
|
||||
assert(e1.getCause.getMessage.contains(
|
||||
"input string of length 6 exceeds char type length limitation: 5"))
|
||||
assert(e1.getCause.getMessage.contains("Exceeds char type length limitation: 5"))
|
||||
val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)"))
|
||||
assert(e2.getCause.getMessage.contains(
|
||||
"input string of length 6 exceeds varchar type length limitation: 5"))
|
||||
assert(e2.getCause.getMessage.contains("Exceeds varchar type length limitation: 5"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -626,8 +613,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
|
|||
sql("SELECT '123456' as col").write.format(format).save(dir.toString)
|
||||
sql(s"CREATE TABLE t (col $typ(2)) using $format LOCATION '$dir'")
|
||||
val e = intercept[SparkException] { sql("select * from t").collect() }
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typ type length limitation: 2"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -654,8 +640,7 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
|
|||
sql(s"CREATE TABLE t (col $typ(2)) using $format")
|
||||
sql(s"ALTER TABLE t SET LOCATION '$dir'")
|
||||
val e = intercept[SparkException] { spark.table("t").collect() }
|
||||
assert(e.getCause.getMessage.contains(
|
||||
s"input string of length 6 exceeds $typ type length limitation: 2"))
|
||||
assert(e.getCause.getMessage.contains(s"Exceeds $typ type length limitation: 2"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue