diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala index 8cf1b7fc71..45b57207c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala @@ -28,7 +28,6 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { val scalaTestUDF = TestScalaUDF(name = "scalaUDF") val pythonTestUDF = TestPythonUDF(name = "pyUDF") - assume(shouldTestPythonUDFs) lazy val base = Seq( (Some(1), Some(1)), (Some(1), Some(2)), (Some(2), Some(1)), @@ -36,6 +35,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { (None, Some(1)), (Some(3), None), (None, None)).toDF("a", "b") test("SPARK-28445: PythonUDF as grouping key and aggregate expressions") { + assume(shouldTestPythonUDFs) val df1 = base.groupBy(scalaTestUDF(base("a") + 1)) .agg(scalaTestUDF(base("a") + 1), scalaTestUDF(count(base("b")))) val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) @@ -44,6 +44,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-28445: PythonUDF as grouping key and used in aggregate expressions") { + assume(shouldTestPythonUDFs) val df1 = base.groupBy(scalaTestUDF(base("a") + 1)) .agg(scalaTestUDF(base("a") + 1) + 1, scalaTestUDF(count(base("b")))) val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) @@ -52,6 +53,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-28445: PythonUDF in aggregate expression has grouping key in its arguments") { + assume(shouldTestPythonUDFs) val df1 = base.groupBy(scalaTestUDF(base("a") + 1)) .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(base("b")))) val df2 = base.groupBy(pythonTestUDF(base("a") + 1)) @@ -60,6 +62,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession { } test("SPARK-28445: PythonUDF over grouping key is argument to aggregate function") { + assume(shouldTestPythonUDFs) val df1 = base.groupBy(scalaTestUDF(base("a") + 1)) .agg(scalaTestUDF(scalaTestUDF(base("a") + 1)), scalaTestUDF(count(scalaTestUDF(base("a") + 1))))