[SPARK-34035][SQL] Refactor ScriptTransformation to remove input parameter and replace it by child.output

### What changes were proposed in this pull request?
Refactor `ScriptTransformation` to remove the `input` parameter and derive the script's input expressions from `child.output` instead.
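In sketch form (condensed from the hunks below), the logical plan node loses its `input` field:

```scala
// Before: input expressions were a separate constructor field, even though
// after analysis they were always equal to child.output.
case class ScriptTransformation(
    input: Seq[Expression],
    script: String,
    output: Seq[Attribute],
    child: LogicalPlan,
    ioschema: ScriptInputOutputSchema) extends UnaryNode

// After: the script always consumes exactly child.output, so the field is
// gone and `references` becomes AttributeSet(child.output).
case class ScriptTransformation(
    script: String,
    output: Seq[Attribute],
    child: LogicalPlan,
    ioschema: ScriptInputOutputSchema) extends UnaryNode
```

The `BaseScriptTransformationExec` trait and both physical operators (`SparkScriptTransformationExec`, `HiveScriptTransformationExec`) drop the field the same way, reading `child.output` instead.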

### Why are the changes needed?
Code cleanup: the `input` field carried no independent information. The parser always populated it with `Seq(UnresolvedStar(None))`, and the analyzer immediately rewrote that star to `child.output` (the rule quoted below, which this commit deletes), so after analysis `input == child.output` held for every plan. Dropping the field simplifies the logical plan and removes the dedicated `ColumnPruning` case; callers that want fewer columns now project them on the child first, as the updated tests show.
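For reference, the analyzer rule in question, verbatim from the Analyzer hunk below:

```scala
// Removed by this commit: any star in `input` was expanded to the child's
// full output, so the field never carried anything other than child.output.
case t: ScriptTransformation if containsStar(t.input) =>
  t.copy(input = t.child.output)
```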

### Does this PR introduce _any_ user-facing change?
No
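To illustrate: a TRANSFORM query such as those exercised by the parser tests below behaves identically; only the internal plan representation changed. A minimal smoke-check sketch, assuming a `SparkSession` named `spark`, a temp view `testData` with columns `a` and `b`, and an environment where the `cat` script is available:

```scala
// Hypothetical check: TRANSFORM results are unaffected by this refactor.
val transformed = spark.sql(
  """SELECT TRANSFORM(a, b)
    |USING 'cat' AS (a string, b string)
    |FROM testData
    |""".stripMargin)
transformed.show()  // same rows before and after this commit
```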

### How was this patch tested?
Existing UTs.

Closes #32228 from AngersZhuuuu/SPARK-34035.

Lead-authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
commit 361444890e (parent 1e64b4fa27)
Authored by Angerszhuuuu on 2021-04-20 14:52:21 +00:00, committed by Wenchen Fan
17 changed files with 21 additions and 133 deletions


@@ -1381,9 +1381,6 @@ class Analyzer(override val catalogManager: CatalogManager)
         } else {
           a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
         }
-      // TODO: Remove this logic and see SPARK-34035
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(input = t.child.output)
       case g: Generate if containsStar(g.generator.children) =>
         throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF")


@@ -208,7 +208,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           .nonEmpty =>
         Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
-      case oldVersion @ ScriptTransformation(_, _, output, _, _)
+      case oldVersion @ ScriptTransformation(_, output, _, _)
           if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
         Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))


@@ -772,9 +772,6 @@ object ColumnPruning extends Rule[LogicalPlan] {
       f.copy(child = prunedChild(child, f.references))
     case e @ Expand(_, _, child) if !child.outputSet.subsetOf(e.references) =>
       e.copy(child = prunedChild(child, e.references))
-    case s @ ScriptTransformation(_, _, _, child, _)
-        if !child.outputSet.subsetOf(s.references) =>
-      s.copy(child = prunedChild(child, s.references))
     // prune unrequired references
     case p @ Project(_, g: Generate) if p.references != g.outputSet =>


@@ -696,8 +696,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
         isDistinct = false)
     ScriptTransformation(
-      // TODO: Remove this logic and see SPARK-34035
-      Seq(UnresolvedStar(None)),
       string(transformClause.script),
       attributes,
       plan,


@@ -17,24 +17,22 @@
 package org.apache.spark.sql.catalyst.plans.logical

-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}

 /**
  * Transforms the input by forking and running the specified script.
  *
- * @param input the set of expression that should be passed to the script.
  * @param script the command that should be executed.
  * @param output the attributes that are produced by the script.
  * @param ioschema the input and output schema applied in the execution of the script.
  */
 case class ScriptTransformation(
-    input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
     child: LogicalPlan,
     ioschema: ScriptInputOutputSchema) extends UnaryNode {
   @transient
-  override lazy val references: AttributeSet = AttributeSet(input.flatMap(_.references))
+  override lazy val references: AttributeSet = AttributeSet(child.output)

   override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
     copy(child = newChild)


@@ -218,30 +218,6 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized, expected)
   }

-  test("Column pruning for ScriptTransformation") {
-    val input = LocalRelation('a.int, 'b.string, 'c.double)
-    val query =
-      ScriptTransformation(
-        Seq('a, 'b),
-        "func",
-        Seq.empty,
-        input,
-        null).analyze
-    val optimized = Optimize.execute(query)
-    val expected =
-      ScriptTransformation(
-        Seq('a, 'b),
-        "func",
-        Seq.empty,
-        Project(
-          Seq('a, 'b),
-          input),
-        null).analyze
-    comparePlans(optimized, expected)
-  }
-
   test("Column pruning on Filter") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
     val plan1 = Filter('a > 1, input).analyze


@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.parser

 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -1074,7 +1074,6 @@ class PlanParserSuite extends AnalysisTest {
         |FROM testData
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("key", StringType)(),
           AttributeReference("value", StringType)()),
@@ -1091,7 +1090,6 @@ class PlanParserSuite extends AnalysisTest {
         |FROM testData
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),
@@ -1108,7 +1106,6 @@ class PlanParserSuite extends AnalysisTest {
         |FROM testData
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", IntegerType)(),
           AttributeReference("b", StringType)(),
@@ -1137,7 +1134,6 @@ class PlanParserSuite extends AnalysisTest {
         |FROM testData
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),


@@ -40,14 +40,13 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

 trait BaseScriptTransformationExec extends UnaryExecNode {
-  def input: Seq[Expression]
   def script: String
   def output: Seq[Attribute]
   def child: SparkPlan
   def ioschema: ScriptTransformationIOSchema

   protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
-    input.map { in =>
+    child.output.map { in =>
       in.dataType match {
         case _: ArrayType | _: MapType | _: StructType =>
           new StructsToJson(ioschema.inputSerdeProps.toMap, in)


@@ -30,14 +30,12 @@ import org.apache.spark.util.CircularBuffer

 /**
  * Transforms the input by forking and running the specified script.
  *
- * @param input the set of expression that should be passed to the script.
  * @param script the command that should be executed.
  * @param output the attributes that are produced by the script.
  * @param child logical plan whose output is transformed.
  * @param ioschema the class set that defines how to handle input/output data.
  */
 case class SparkScriptTransformationExec(
-    input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
     child: SparkPlan,


@@ -594,9 +594,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object SparkScripts extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.ScriptTransformation(input, script, output, child, ioschema) =>
+      case logical.ScriptTransformation(script, output, child, ioschema) =>
         SparkScriptTransformationExec(
-          input,
           script,
           output,
           planLater(child),


@@ -30,7 +30,7 @@ import org.apache.spark.{SparkException, TaskContext, TestUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GenericInternalRow}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -64,7 +64,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
   }

   def createScriptTransformationExec(
-      input: Seq[Expression],
       script: String,
       output: Seq[Attribute],
       child: SparkPlan,
@@ -77,7 +76,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(rowsDf.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -95,7 +93,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(rowsDf.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -152,12 +149,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr,
-          df.col("c").expr,
-          df.col("d").expr,
-          df.col("e").expr),
         script = "cat",
         output = Seq(
           AttributeReference("key", StringType)(),
@@ -170,11 +161,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         'b.cast("string").as("value")).collect())
     checkAnswer(
-      df,
+      df.select('a, 'b),
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr),
         script = "cat",
         output = Seq(
           AttributeReference("key", StringType)(),
@@ -187,10 +175,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         'b.cast("string").as("value")).collect())
     checkAnswer(
-      df,
+      df.select('a),
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr),
         script = "cat",
         output = Seq(
           AttributeReference("key", StringType)(),
@@ -211,7 +197,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     val e = intercept[SparkException] {
       val plan =
         createScriptTransformationExec(
-          input = Seq(rowsDf.col("a").expr),
           script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
           child = rowsDf.queryExecution.sparkPlan,
@@ -239,17 +224,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr,
-          df.col("c").expr,
-          df.col("d").expr,
-          df.col("e").expr,
-          df.col("f").expr,
-          df.col("g").expr,
-          df.col("h").expr,
-          df.col("i").expr,
-          df.col("j").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", IntegerType)(),
@@ -293,12 +267,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr,
-          df.col("c").expr,
-          df.col("d").expr,
-          df.col("e").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", CalendarIntervalType)(),
@@ -408,11 +376,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
     checkAnswer(
-      df,
+      df.select('a, 'b),
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", StringType)(),
@@ -493,14 +458,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr,
-          df.col("c").expr,
-          df.col("d").expr,
-          df.col("e").expr,
-          df.col("f").expr,
-          df.col("g").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", ArrayType(IntegerType))(),


@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.TestUtils
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.test.SharedSparkSession
@@ -26,13 +26,11 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
   import testImplicits._

   override def createScriptTransformationExec(
-      input: Seq[Expression],
       script: String,
       output: Seq[Attribute],
       child: SparkPlan,
       ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = {
     SparkScriptTransformationExec(
-      input = input,
       script = script,
       output = output,
       child = child,


@@ -296,7 +296,6 @@ class SparkSqlParserSuite extends AnalysisTest {
         |FROM testData
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),
@@ -316,7 +315,6 @@ class SparkSqlParserSuite extends AnalysisTest {
         |HAVING sum(b) > 10
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),
@@ -347,7 +345,6 @@ class SparkSqlParserSuite extends AnalysisTest {
         |WINDOW w AS (PARTITION BY a ORDER BY b)
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),
@@ -385,7 +382,6 @@ class SparkSqlParserSuite extends AnalysisTest {
         |HAVING sum(b) > 10
       """.stripMargin,
       ScriptTransformation(
-        Seq(UnresolvedStar(None)),
         "cat",
         Seq(AttributeReference("a", StringType)(),
           AttributeReference("b", StringType)(),


@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command
 import java.util.Locale

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans
 import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
@@ -280,7 +280,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession {
   test("transform query spec") {
     val p = Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), plans.table("e"))
-    val s = ScriptTransformation(Seq(UnresolvedStar(None)), "func", Seq.empty, p, null)
+    val s = ScriptTransformation("func", Seq.empty, p, null)
     compareTransformQuery("select transform(a, b) using 'func' from e where f < 10",
       s.copy(child = p.copy(child = p.child.where('f < 10)),


@@ -243,9 +243,9 @@ private[hive] trait HiveStrategies {
   object HiveScripts extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case ScriptTransformation(input, script, output, child, ioschema) =>
+      case ScriptTransformation(script, output, child, ioschema) =>
         val hiveIoSchema = ScriptTransformationIOSchema(ioschema)
-        HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
+        HiveScriptTransformationExec(script, output, planLater(child), hiveIoSchema) :: Nil
       case _ => Nil
     }
   }


@@ -42,14 +42,12 @@ import org.apache.spark.util.{CircularBuffer, Utils}

 /**
  * Transforms the input by forking and running the specified script.
  *
- * @param input the set of expression that should be passed to the script.
  * @param script the command that should be executed.
  * @param output the attributes that are produced by the script.
  * @param child logical plan whose output is transformed.
  * @param ioschema the class set that defines how to handle input/output data.
  */
 private[hive] case class HiveScriptTransformationExec(
-    input: Seq[Expression],
     script: String,
     output: Seq[Attribute],
     child: SparkPlan,
@@ -142,14 +140,14 @@ private[hive] case class HiveScriptTransformationExec(
     val (outputStream, proc, inputStream, stderrBuffer) = initProc

-    val (inputSerde, inputSoi) = initInputSerDe(ioschema, input).getOrElse((null, null))
+    val (inputSerde, inputSoi) = initInputSerDe(ioschema, child.output).getOrElse((null, null))

     // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null
     // We will use StringBuffer to pass data, in this case, we should cast data as string too.
     val finalInput = if (inputSerde == null) {
       inputExpressionsWithoutSerde
     } else {
-      input
+      child.output
     }

     val outputProjection = new InterpretedProjection(finalInput, child.output)


@@ -26,7 +26,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.functions._
@@ -40,13 +40,11 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
   import ScriptTransformationIOSchema._

   override def createScriptTransformationExec(
-      input: Seq[Expression],
       script: String,
       output: Seq[Attribute],
       child: SparkPlan,
       ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = {
     HiveScriptTransformationExec(
-      input = input,
       script = script,
       output = output,
       child = child,
@@ -68,7 +66,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(rowsDf.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
@@ -86,7 +83,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(rowsDf.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
@@ -107,7 +103,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     val e = intercept[SparkException] {
       val plan =
         createScriptTransformationExec(
-          input = Seq(rowsDf.col("a").expr),
           script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
           child = rowsDf.queryExecution.sparkPlan,
@@ -129,7 +124,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     checkAnswer(
       rowsDf,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(rowsDf.col("name").expr),
         script = "cat",
         output = Seq(AttributeReference("name", StringType)()),
         child = child,
@@ -146,7 +140,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     val e = intercept[SparkException] {
       val plan =
         createScriptTransformationExec(
-          input = Seq(rowsDf.col("a").expr),
           script = "some_non_existent_command",
           output = Seq(AttributeReference("a", StringType)()),
           child = rowsDf.queryExecution.sparkPlan,
@@ -334,12 +327,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     // Hive serde support ArrayType/MapType/StructType as input and output data type
     checkAnswer(
-      df,
+      df.select('c, 'd, 'e),
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("c").expr,
-          df.col("d").expr,
-          df.col("e").expr),
         script = "cat",
         output = Seq(
           AttributeReference("c", ArrayType(IntegerType))(),
@@ -387,12 +376,11 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     val e1 = intercept[SparkException] {
       val plan = createScriptTransformationExec(
-        input = Seq(df.col("a").expr, df.col("b").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", IntegerType)(),
           AttributeReference("b", CalendarIntervalType)()),
-        child = df.queryExecution.sparkPlan,
+        child = df.select('a, 'b).queryExecution.sparkPlan,
         ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }.getMessage
@@ -400,12 +388,11 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     val e2 = intercept[SparkException] {
       val plan = createScriptTransformationExec(
-        input = Seq(df.col("a").expr, df.col("c").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", IntegerType)(),
           AttributeReference("c", new TestUDT.MyDenseVectorUDT)()),
-        child = df.queryExecution.sparkPlan,
+        child = df.select('a, 'c).queryExecution.sparkPlan,
         ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }.getMessage
@@ -551,11 +538,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(
-          df.col("a").expr,
-          df.col("b").expr,
-          df.col("c").expr,
-          df.col("d").expr),
         script = "cat",
         output = Seq(
           AttributeReference("a", DayTimeIntervalType)(),
@@ -578,7 +560,6 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     checkAnswer(
       df,
       (child: SparkPlan) => createScriptTransformationExec(
-        input = Seq(df.col("a").expr),
         script = "cat",
         output = Seq(AttributeReference("a", DayTimeIntervalType)()),
         child = child,