[SPARK-31526][SQL][TESTS] Add a new test suite for ExpressionInfo

### What changes were proposed in this pull request?

This PR intends to add a new test suite for `ExpressionInfo`. The major changes are as follows:

 - Added a new test suite named `ExpressionInfoSuite`
 - To improve test coverage, added a test for error handling in `ExpressionInfoSuite` (see the sketch after this list)
 - Moved the `ExpressionInfo`-related tests from `UDFSuite` to `ExpressionInfoSuite`
 - Moved the related tests from `SQLQuerySuite` to `ExpressionInfoSuite`
 - Added a comment in `ExpressionInfoSuite` (follow-up of https://github.com/apache/spark/pull/28224)
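
For example, the new error-handling coverage boils down to asserting that `ExpressionInfo` rejects malformed metadata at construction time. A minimal standalone sketch of one such case, abridged from the suite below (assumes Spark's test utilities are on the classpath):

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

// Abridged from ExpressionInfoSuite: a malformed `since` version must
// fail fast with an IllegalArgumentException.
class ExpressionInfoErrorSketch extends SparkFunSuite {
  test("malformed 'since' is rejected") {
    val errMsg = intercept[IllegalArgumentException] {
      // Constructor args, per the suite below: className, db, name, usage,
      // arguments, examples, note, group, since, deprecated.
      new ExpressionInfo("testClass", null, "testName", null, "", "", "", "", "-3.0.0", "")
    }.getMessage
    assert(errMsg.contains("'since' is malformed in the expression [testName]."))
  }
}
```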

### Why are the changes needed?

To improve test coverage and the organization of the test suites.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Added tests.

Closes #28308 from maropu/SPARK-31526.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Committed 2020-04-24 11:19:20 +09:00
commit 42f496f6ac (parent f093480af9)
4 changed files with 162 additions and 111 deletions

ExpressionDescription.java

@@ -103,6 +103,12 @@ public @interface ExpressionDescription {
String arguments() default "";
String examples() default "";
String note() default "";
/**
* Valid group names are almost the same as those defined as `groupname` in
* `sql/functions.scala`. But `collection_funcs` is split into three fine-grained groups:
* `array_funcs`, `map_funcs`, and `json_funcs`. See `ExpressionInfo` for the
* detailed group names.
*/
String group() default "";
String since() default "";
String deprecated() default "";

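For illustration, this is roughly how an expression would declare these fields; `MadeUpArraySize` is a hypothetical class invented for this sketch, not part of the PR:

```scala
import org.apache.spark.sql.catalyst.expressions.ExpressionDescription

// Hypothetical annotation usage; the field names match the declaration
// above, and `group` must be one of the valid fine-grained group names
// (e.g. `array_funcs`, `map_funcs`, `json_funcs`).
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the number of elements in `array` (made up).",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3));
       3
  """,
  group = "array_funcs",
  since = "3.0.0")
class MadeUpArraySize // a real expression would extend Expression; elided here
```
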
SQLQuerySuite.scala

@@ -22,8 +22,6 @@ import java.net.{MalformedURLException, URL}
import java.sql.{Date, Timestamp}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.parallel.immutable.ParVector
import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -31,7 +29,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -126,83 +123,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}
test("using _FUNC_ instead of function names in examples") {
val exampleRe = "(>.*;)".r
val setStmtRe = "(?i)^(>\\s+set\\s+).+".r
val ignoreSet = Set(
// Examples for CaseWhen show simpler syntax:
// `CASE WHEN ... THEN ... WHEN ... THEN ... END`
"org.apache.spark.sql.catalyst.expressions.CaseWhen",
// _FUNC_ is replaced by `locate` but `locate(... IN ...)` is not supported
"org.apache.spark.sql.catalyst.expressions.StringLocate",
// _FUNC_ is replaced by `%` which causes a parsing error on `SELECT %(2, 1.8)`
"org.apache.spark.sql.catalyst.expressions.Remainder",
// Examples demonstrate alternative names; see SPARK-20749
"org.apache.spark.sql.catalyst.expressions.Length")
spark.sessionState.functionRegistry.listFunction().foreach { funcId =>
val info = spark.sessionState.catalog.lookupFunctionInfo(funcId)
val className = info.getClassName
withClue(s"Expression class '$className'") {
val exprExamples = info.getOriginalExamples
if (!exprExamples.isEmpty && !ignoreSet.contains(className)) {
assert(exampleRe.findAllIn(exprExamples).toIterable
.filter(setStmtRe.findFirstIn(_).isEmpty) // Ignore SET commands
.forall(_.contains("_FUNC_")))
}
}
}
}
test("check outputs of expression examples") {
def unindentAndTrim(s: String): String = {
s.replaceAll("\n\\s+", "\n").trim
}
val beginSqlStmtRe = " > ".r
val endSqlStmtRe = ";\n".r
def checkExampleSyntax(example: String): Unit = {
val beginStmtNum = beginSqlStmtRe.findAllIn(example).length
val endStmtNum = endSqlStmtRe.findAllIn(example).length
assert(beginStmtNum === endStmtNum,
"The number of ` > ` does not match to the number of `;`")
}
val exampleRe = """^(.+);\n(?s)(.+)$""".r
val ignoreSet = Set(
// One of the examples shows getting the current timestamp
"org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
// Random output without a seed
"org.apache.spark.sql.catalyst.expressions.Rand",
"org.apache.spark.sql.catalyst.expressions.Randn",
"org.apache.spark.sql.catalyst.expressions.Shuffle",
"org.apache.spark.sql.catalyst.expressions.Uuid",
// The example calls methods that return unstable results.
"org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection")
val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
parFuncs.foreach { funcId =>
// Examples can change settings. We clone the session to prevent tests from clashing.
val clonedSpark = spark.cloneSession()
// Coalescing partitions can change result order, so disable it.
clonedSpark.sessionState.conf.setConf(SQLConf.COALESCE_PARTITIONS_ENABLED, false)
val info = clonedSpark.sessionState.catalog.lookupFunctionInfo(funcId)
val className = info.getClassName
if (!ignoreSet.contains(className)) {
withClue(s"Function '${info.getName}', Expression class '$className'") {
val example = info.getExamples
checkExampleSyntax(example)
example.split(" > ").toList.foreach(_ match {
case exampleRe(sql, output) =>
val df = clonedSpark.sql(sql)
val actual = unindentAndTrim(
hiveResultString(df.queryExecution.executedPlan).mkString("\n"))
val expected = unindentAndTrim(output)
assert(actual === expected)
case _ =>
}
}
}
}
}
test("SPARK-6743: no columns from cache") {
Seq(
(83, 0, 38),

UDFSuite.scala

@@ -20,8 +20,6 @@ package org.apache.spark.sql
import java.math.BigDecimal
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -534,35 +532,6 @@ class UDFSuite extends QueryTest with SharedSparkSession {
assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
}
test("Replace _FUNC_ in UDF ExpressionInfo") {
val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("upper"))
assert(info.getName === "upper")
assert(info.getClassName === "org.apache.spark.sql.catalyst.expressions.Upper")
assert(info.getUsage === "upper(str) - Returns `str` with all characters changed to uppercase.")
assert(info.getExamples.contains("> SELECT upper('SparkSql');"))
assert(info.getSince === "1.0.1")
assert(info.getNote === "")
assert(info.getExtended.contains("> SELECT upper('SparkSql');"))
}
test("group info in ExpressionInfo") {
val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("sum"))
assert(info.getGroup === "agg_funcs")
Seq("agg_funcs", "array_funcs", "datetime_funcs", "json_funcs", "map_funcs", "window_funcs")
.foreach { groupName =>
val info = new ExpressionInfo(
"testClass", null, "testName", null, "", "", "", groupName, "", "")
assert(info.getGroup === groupName)
}
val errMsg = intercept[IllegalArgumentException] {
val invalidGroupName = "invalid_group_funcs"
new ExpressionInfo("testClass", null, "testName", null, "", "", "", invalidGroupName, "", "")
}.getMessage
assert(errMsg.contains("'group' is malformed in the expression [testName]."))
}
test("SPARK-28521 error message for CAST(parameter types contains DataType)") {
val e = intercept[AnalysisException] {
spark.sql("SELECT CAST(1)")

ExpressionInfoSuite.scala (new file)

@@ -0,0 +1,156 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.expressions
import scala.collection.parallel.immutable.ParVector
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
test("Replace _FUNC_ in ExpressionInfo") {
val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("upper"))
assert(info.getName === "upper")
assert(info.getClassName === "org.apache.spark.sql.catalyst.expressions.Upper")
assert(info.getUsage === "upper(str) - Returns `str` with all characters changed to uppercase.")
assert(info.getExamples.contains("> SELECT upper('SparkSql');"))
assert(info.getSince === "1.0.1")
assert(info.getNote === "")
assert(info.getExtended.contains("> SELECT upper('SparkSql');"))
}
test("group info in ExpressionInfo") {
val info = spark.sessionState.catalog.lookupFunctionInfo(FunctionIdentifier("sum"))
assert(info.getGroup === "agg_funcs")
Seq("agg_funcs", "array_funcs", "datetime_funcs", "json_funcs", "map_funcs", "window_funcs")
.foreach { groupName =>
val info = new ExpressionInfo(
"testClass", null, "testName", null, "", "", "", groupName, "", "")
assert(info.getGroup === groupName)
}
val errMsg = intercept[IllegalArgumentException] {
val invalidGroupName = "invalid_group_funcs"
new ExpressionInfo("testClass", null, "testName", null, "", "", "", invalidGroupName, "", "")
}.getMessage
assert(errMsg.contains("'group' is malformed in the expression [testName]."))
}
test("error handling in ExpressionInfo") {
val errMsg1 = intercept[IllegalArgumentException] {
val invalidNote = " invalid note"
new ExpressionInfo("testClass", null, "testName", null, "", "", invalidNote, "", "", "")
}.getMessage
assert(errMsg1.contains("'note' is malformed in the expression [testName]."))
val errMsg2 = intercept[IllegalArgumentException] {
val invalidSince = "-3.0.0"
new ExpressionInfo("testClass", null, "testName", null, "", "", "", "", invalidSince, "")
}.getMessage
assert(errMsg2.contains("'since' is malformed in the expression [testName]."))
val errMsg3 = intercept[IllegalArgumentException] {
val invalidDeprecated = " invalid deprecated"
new ExpressionInfo("testClass", null, "testName", null, "", "", "", "", "", invalidDeprecated)
}.getMessage
assert(errMsg3.contains("'deprecated' is malformed in the expression [testName]."))
}
test("using _FUNC_ instead of function names in examples") {
val exampleRe = "(>.*;)".r
val setStmtRe = "(?i)^(>\\s+set\\s+).+".r
val ignoreSet = Set(
// Examples for CaseWhen show simpler syntax:
// `CASE WHEN ... THEN ... WHEN ... THEN ... END`
"org.apache.spark.sql.catalyst.expressions.CaseWhen",
// _FUNC_ is replaced by `locate` but `locate(... IN ...)` is not supported
"org.apache.spark.sql.catalyst.expressions.StringLocate",
// _FUNC_ is replaced by `%` which causes a parsing error on `SELECT %(2, 1.8)`
"org.apache.spark.sql.catalyst.expressions.Remainder",
// Examples demonstrate alternative names; see SPARK-20749
"org.apache.spark.sql.catalyst.expressions.Length")
spark.sessionState.functionRegistry.listFunction().foreach { funcId =>
val info = spark.sessionState.catalog.lookupFunctionInfo(funcId)
val className = info.getClassName
withClue(s"Expression class '$className'") {
val exprExamples = info.getOriginalExamples
if (!exprExamples.isEmpty && !ignoreSet.contains(className)) {
assert(exampleRe.findAllIn(exprExamples).toIterable
.filter(setStmtRe.findFirstIn(_).isEmpty) // Ignore SET commands
.forall(_.contains("_FUNC_")))
}
}
}
}
test("check outputs of expression examples") {
def unindentAndTrim(s: String): String = {
s.replaceAll("\n\\s+", "\n").trim
}
val beginSqlStmtRe = " > ".r
val endSqlStmtRe = ";\n".r
def checkExampleSyntax(example: String): Unit = {
val beginStmtNum = beginSqlStmtRe.findAllIn(example).length
val endStmtNum = endSqlStmtRe.findAllIn(example).length
assert(beginStmtNum === endStmtNum,
"The number of ` > ` does not match to the number of `;`")
}
val exampleRe = """^(.+);\n(?s)(.+)$""".r
val ignoreSet = Set(
// One of the examples shows getting the current timestamp
"org.apache.spark.sql.catalyst.expressions.UnixTimestamp",
// Random output without a seed
"org.apache.spark.sql.catalyst.expressions.Rand",
"org.apache.spark.sql.catalyst.expressions.Randn",
"org.apache.spark.sql.catalyst.expressions.Shuffle",
"org.apache.spark.sql.catalyst.expressions.Uuid",
// The example calls methods that return unstable results.
"org.apache.spark.sql.catalyst.expressions.CallMethodViaReflection")
val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
parFuncs.foreach { funcId =>
// Examples can change settings. We clone the session to prevent tests from clashing.
val clonedSpark = spark.cloneSession()
// Coalescing partitions can change result order, so disable it.
clonedSpark.sessionState.conf.setConf(SQLConf.COALESCE_PARTITIONS_ENABLED, false)
val info = clonedSpark.sessionState.catalog.lookupFunctionInfo(funcId)
val className = info.getClassName
if (!ignoreSet.contains(className)) {
withClue(s"Function '${info.getName}', Expression class '$className'") {
val example = info.getExamples
checkExampleSyntax(example)
example.split(" > ").toList.foreach {
case exampleRe(sql, output) =>
val df = clonedSpark.sql(sql)
val actual = unindentAndTrim(
hiveResultString(df.queryExecution.executedPlan).mkString("\n"))
val expected = unindentAndTrim(output)
assert(actual === expected)
case _ =>
}
}
}
}
}
}
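
As an aside, the example-parsing convention that `exampleRe` relies on (each chunk is `<sql statement>;\n<expected output>`) can be shown with a standalone toy; this sketch is not part of the PR and only mirrors the regex above:

```scala
// A toy illustrating the `exampleRe` convention used in the suite above:
// split on " > ", then match "<sql>;\n<output>" in each chunk.
object ExampleParseSketch extends App {
  val exampleRe = """^(.+);\n(?s)(.+)$""".r
  val examples =
    """
    Examples:
      > SELECT upper('SparkSql');
       SPARKSQL
    """
  examples.split(" > ").toList.foreach {
    case exampleRe(sql, output) =>
      println(s"sql=[${sql.trim}] expected=[${output.trim}]")
    case _ => // the "Examples:" header chunk; ignored
  }
}
```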