[SPARK-34969][SPARK-34906][SQL] Followup for Refactor TreeNode's children handling methods into specialized traits

### What changes were proposed in this pull request?

This is a followup for https://github.com/apache/spark/pull/31932.
In this PR we:
- Introduce the `QuaternaryLike` trait for node types with 4 children.
- Specialize more node types (a minimal sketch of the pattern follows this list).
- Fix a number of style errors that were introduced in the original PR.
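
For context, the pattern started in #31932 and extended here replaces hand-written `children` overrides with the `LeafLike`/`UnaryLike`/`BinaryLike`/`TernaryLike`/`QuaternaryLike` traits, which derive `children` once from named accessors. Below is a minimal sketch of the shape on a made-up node; the class, its fields and its semantics are illustrative only and assume the Catalyst API as of this commit.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.types.DataType

// Before: the class would carry `override def children: Seq[Expression] = Seq(key, value)`.
// After: mix in BinaryLike and name the two slots; the trait supplies `children`
// as a @transient final lazy val built from left :: right :: Nil.
case class KeyValue(key: Expression, value: Expression)
  extends Expression with CodegenFallback with BinaryLike[Expression] {
  override def left: Expression = key
  override def right: Expression = value
  override def dataType: DataType = value.dataType
  override def nullable: Boolean = value.nullable
  override def eval(input: InternalRow): Any = value.eval(input) // placeholder semantics
}
```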

### Why are the changes needed?

This completes the refactoring from #31932: more node types now derive `children` from the specialized traits (defined once as a `@transient` final `lazy val`) instead of hand-written overrides, and the style issues introduced by the original PR are cleaned up.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of Catalyst tree nodes.

### How was this patch tested?

This is a refactoring; it is covered by the existing tests.

Closes #32065 from dbaliafroozeh/FollowupSPARK-34906.

Authored-by: Ali Afroozeh <ali.afroozeh@databricks.com>
Signed-off-by: herman <herman@databricks.com>
Commit 06c09a79b3 (parent 0aa2c284e4), authored by Ali Afroozeh on 2021-04-07 09:50:30 +02:00 and committed by herman.
49 changed files with 127 additions and 87 deletions


@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
weightExpr: Expression,
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[SummarizerBuffer]
with ImplicitCastInputTypes
with BinaryLike[Expression] {
override def eval(state: SummarizerBuffer): Any = {
val metrics = requestedMetrics.map {
@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {
override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
override def left: Expression = featuresExpr
override def right: Expression = weightExpr
override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
val features = vectorUDT.deserialize(featuresExpr.eval(row))


@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
* An expression with four inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class QuaternaryExpression extends Expression {
abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {
override def foldable: Boolean = children.forall(_.foldable)
@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
* If subclass of QuaternaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
val value1 = exprs(0).eval(input)
val value1 = first.eval(input)
if (value1 != null) {
val value2 = exprs(1).eval(input)
val value2 = second.eval(input)
if (value2 != null) {
val value3 = exprs(2).eval(input)
val value3 = third.eval(input)
if (value3 != null) {
val value4 = exprs(3).eval(input)
val value4 = fourth.eval(input)
if (value4 != null) {
return nullSafeEval(value1, value2, value3, value4)
}
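
The rewritten `eval` above short-circuits on the first null child, so a concrete node only implements `nullSafeEval` for the all-non-null case. A hypothetical subclass, purely for illustration (the class and its semantics are not part of this PR; codegen is sidestepped via `CodegenFallback` for brevity):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, QuaternaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, DoubleType}

// Clamps `value` into [lo, hi] and multiplies by `scale`; any null input yields
// null, courtesy of QuaternaryExpression.eval above.
case class ClampScale(value: Expression, lo: Expression, hi: Expression, scale: Expression)
  extends QuaternaryExpression with CodegenFallback {
  override def first: Expression = value
  override def second: Expression = lo
  override def third: Expression = hi
  override def fourth: Expression = scale
  override def dataType: DataType = DoubleType
  // Only reached once eval has seen four non-null inputs.
  override protected def nullSafeEval(v: Any, l: Any, h: Any, s: Any): Any = {
    val clamped =
      math.max(l.asInstanceOf[Double], math.min(h.asInstanceOf[Double], v.asInstanceOf[Double]))
    clamped * s.asInstanceOf[Double]
  }
}
```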


@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* }}}
*/
abstract class PartitionTransformExpression extends Expression with Unevaluable
with UnaryLike[Expression] {
with UnaryLike[Expression] {
override def nullable: Boolean = true
}


@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")


@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
group = "agg_funcs",
since = "3.0.0")
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {
override def prettyName: String = "count_if"


@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch
@ -60,7 +61,9 @@ case class CountMinSketchAgg(
seedExpression: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
extends TypedImperativeAggregate[CountMinSketch]
with ExpectsInputTypes
with QuaternaryLike[Expression] {
def this(
child: Expression,
@ -145,8 +148,10 @@ case class CountMinSketchAgg(
override def defaultResult: Option[Literal] =
Option(Literal.create(eval(createAggregationBuffer()), dataType))
override def children: Seq[Expression] =
Seq(child, epsExpression, confidenceExpression, seedExpression)
override def prettyName: String = "count_min_sketch"
override def first: Expression = child
override def second: Expression = epsExpression
override def third: Expression = confidenceExpression
override def fourth: Expression = seedExpression
}


@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
* Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL.
*/
abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
override def left: Expression = x
override def right: Expression = y
override def nullable: Boolean = true
override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
protected def updateExpressionsDef: Seq[Expression] = {
val newN = n + 1.0
val dx = x - xAvg
val dy = y - yAvg
val dx = left - xAvg
val dy = right - yAvg
val dyN = dy / newN
val newXAvg = xAvg + dx / newN
val newYAvg = yAvg + dyN
val newCk = ck + dx * (y - newYAvg)
val newCk = ck + dx * (right - newYAvg)
val isNull = x.isNull || y.isNull
val isNull = left.isNull || right.isNull
Seq(
If(isNull, n, newN),
If(isNull, xAvg, newXAvg),
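
A note on the `Covariance` change above: `BinaryLike` declares `left` and `right` as abstract members, so promoting the constructor parameters to `val left` / `val right` implements the trait directly and the separate accessor overrides can be dropped. The same idiom on a hypothetical helper class (illustrative only):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.trees.BinaryLike

// The constructor vals satisfy BinaryLike's abstract `left`/`right`, so every
// subclass inherits the trait-provided `children` (left :: right :: Nil) for free.
abstract class PairwiseOp(val left: Expression, val right: Expression)
  extends Expression with BinaryLike[Expression]
```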


@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs",
since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {
override def nullable: Boolean = true


@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
with UnaryLike[Expression] {
with UnaryLike[Expression] {
val child: Expression


@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* can cause GC paused and eventually OutOfMemory Errors.
*/
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
with UnaryLike[Expression] {
with UnaryLike[Expression] {
val child: Expression


@ -175,7 +175,7 @@ object GroupingSets {
group = "agg_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class Grouping(child: Expression) extends Expression with Unevaluable
with UnaryLike[Expression] {
with UnaryLike[Expression] {
@transient
override lazy val references: AttributeSet =
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)


@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
override def nullable: Boolean = arguments.exists(_.nullable)
override def children: Seq[Expression] = arguments ++ functions
/**
* Arguments of the higher ordered function.
*/
@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
/**
* Trait for functions having as input one argument and one function.
*/
trait SimpleHigherOrderFunction extends HigherOrderFunction {
trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {
def argument: Expression
@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {
def functionForEval: Expression = functionsForEval.head
override def left: Expression = argument
override def right: Expression = function
/**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code.
@ -694,7 +696,7 @@ case class ArrayAggregate(
zero: Expression,
merge: Expression,
finish: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
def this(argument: Expression, zero: Expression, merge: Expression) = {
this(argument, zero, merge, LambdaFunction.identity)
@ -760,6 +762,11 @@ case class ArrayAggregate(
}
override def prettyName: String = "aggregate"
override def first: Expression = argument
override def second: Expression = zero
override def third: Expression = merge
override def fourth: Expression = finish
}
/**
@ -884,7 +891,7 @@ case class TransformValues(
since = "3.0.0",
group = "lambda_funcs")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
def functionForEval: Expression = functionsForEval.head
@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
}
override def prettyName: String = "map_zip_with"
override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
}
// scalastyle:off line.size.limit
@ -1063,7 +1074,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
group = "lambda_funcs")
// scalastyle:on line.size.limit
case class ZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
def functionForEval: Expression = functionsForEval.head
@ -1071,7 +1082,7 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
override def functions: Seq[Expression] = List(function)
override def functions: Seq[Expression] = function :: Nil
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
}
override def prettyName: String = "zip_with"
override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
}
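
With `ArrayAggregate`, `MapZipWith` and `ZipWith` wired into the traits above, `children` keeps its previous contents and order (`argument, zero, merge, finish` for `ArrayAggregate`), so children-driven traversals behave exactly as before. A trivial, illustrative traversal that relies only on `children`:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression

// Counts the nodes of an expression tree by walking `children`, which the
// *Like traits now provide as a memoized lazy val rather than a hand-built Seq.
def countNodes(e: Expression): Int = 1 + e.children.map(countNodes).sum
```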


@ -1488,7 +1488,6 @@ case class WidthBucket(
numBucket: Expression)
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
override def dataType: DataType = LongType
override def nullable: Boolean = true
@ -1507,4 +1506,9 @@ case class WidthBucket(
"org.apache.spark.sql.catalyst.expressions.WidthBucket" +
s".computeBucketNumber($input, $min, $max, $numBucket)")
}
override def first: Expression = value
override def second: Expression = minValue
override def third: Expression = maxValue
override def fourth: Expression = numBucket
}
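
The four trait slots of `WidthBucket` line up with the arguments of the `width_bucket` SQL function (value, minimum, maximum, number of buckets). A small usage sketch, assuming a running `SparkSession` named `spark`:

```scala
// (10.6 - 0.2) / 5 = 2.08 per bucket, so 5.3 should land in bucket 3.
// WidthBucket.dataType is LongType, hence getLong.
val bucket = spark.sql("SELECT width_bucket(5.3, 0.2, 10.6, 5) AS b").head().getLong(0)
assert(bucket == 3L)
```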


@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
override def prettyName: String = "regexp_replace"
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
"""
})
}
override def first: Expression = subject
override def second: Expression = regexp
override def third: Expression = rep
override def fourth: Expression = pos
}
object RegExpReplace {


@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) {
@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
"org.apache.spark.sql.catalyst.expressions.Overlay" +
s".calculate($input, $replace, $pos, $len);")
}
override def first: Expression = input
override def second: Expression = replace
override def third: Expression = pos
override def fourth: Expression = len
}
object StringTranslate {


@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab
case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
with UnaryLike[Expression] {
with UnaryLike[Expression] {
def this() = this(Literal(1))


@ -431,7 +431,7 @@ case class InsertAction(
}
case class Assignment(key: Expression, value: Expression) extends Expression
with Unevaluable with BinaryLike[Expression] {
with Unevaluable with BinaryLike[Expression] {
override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException("nullable")
override def left: Expression = key


@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
def third: T
@transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
}
trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
def first: T
def second: T
def third: T
def fourth: T
@transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
}
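
This `QuaternaryLike` trait is the only new piece of machinery in the followup. Its effect is easiest to see in isolation; the toy below is not Spark code but mirrors the trait's shape (the real trait additionally marks `children` as `@transient` so it is not serialized with the node):

```scala
// Stand-in for TreeNode: only the `children` contract matters here.
trait Node { def children: Seq[Node] }

// Mirrors QuaternaryLike: a self-typed trait that assembles `children` from the
// four named slots exactly once per node, instead of a def that allocates a
// fresh Seq on every access.
trait FourChildren { self: Node =>
  def first: Node
  def second: Node
  def third: Node
  def fourth: Node
  final lazy val children: Seq[Node] = first :: second :: third :: fourth :: Nil
}

case class Leaf(name: String) extends Node { val children: Seq[Node] = Nil }
case class Quad(first: Node, second: Node, third: Node, fourth: Node)
  extends Node with FourChildren

// Quad(Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d")).children
//   == List(Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d"))
```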


@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
case class HugeCodeIntExpression(value: Int) extends Expression {
case class HugeCodeIntExpression(value: Int) extends LeafExpression {
override def nullable: Boolean = true
override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Nil
override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Assuming HugeMethodLimit to be 8000
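
In the same spirit, test-only expressions can lean on the existing leaf/unary base classes instead of overriding `children`; `HugeCodeIntExpression` above and `CodegenFallbackExpr` further down do exactly that. A hypothetical helper in the same style (illustrative only):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}

// A constant test expression: LeafExpression already fixes `children` to Nil,
// so no children override is needed.
case class ConstantInt(value: Int) extends LeafExpression with CodegenFallback {
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType
  override def eval(input: InternalRow): Any = value
}
```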


@ -22,7 +22,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
/**
* A special `Command` which writes data out and updates metrics.
*/
trait DataWritingCommand extends Command {
trait DataWritingCommand extends UnaryCommand {
/**
* The input query plan that produces the data to be written.
* IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
@ -39,7 +39,7 @@ trait DataWritingCommand extends Command {
*/
def query: LogicalPlan
override final def children: Seq[LogicalPlan] = query :: Nil
override final def child: LogicalPlan = query
// Output column names of the analyzed input query plan.
def outputColumnNames: Seq[String]
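
`UnaryCommand` here is, to my reading of the original PR, simply `Command` with `UnaryLike[LogicalPlan]`, so a data-writing command only has to point `child` at its `query` (the override above is final) and implement `run`. A hedged sketch of a hypothetical command on this API; every name except `DataWritingCommand` itself is made up:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand

// Writes nothing; exists only to show the shape required after the change.
case class NoopWriteCommand(query: LogicalPlan, outputColumnNames: Seq[String])
  extends DataWritingCommand {
  // `child` is already final (it returns `query`), so only `run` is left to implement.
  override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = Seq.empty
}
```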


@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetric
@ -37,7 +37,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
trait RunnableCommand extends LeafCommand {
trait RunnableCommand extends Command {
override def children: Seq[LogicalPlan] = Nil
// The map used to record the metrics of running the command. This will be passed to
// `ExecutedCommand` during query planning.


@ -31,7 +31,7 @@ case class AddPartitionExec(
table: SupportsPartitionManagement,
partSpecs: Seq[ResolvedPartitionSpec],
ignoreIfExists: Boolean,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
import DataSourceV2Implicits._
override def output: Seq[Attribute] = Seq.empty


@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac
case class AlterNamespaceSetPropertiesExec(
catalog: SupportsNamespaces,
namespace: Seq[String],
props: Map[String, String]) extends V2CommandExec {
props: Map[String, String]) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
val changes = props.map{ case (k, v) =>
NamespaceChange.setProperty(k, v)


@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
case class AlterTableExec(
catalog: TableCatalog,
ident: Identifier,
changes: Seq[TableChange]) extends V2CommandExec {
changes: Seq[TableChange]) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty


@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdenti
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.storage.StorageLevel
trait BaseCacheTableExec extends V2CommandExec {
trait BaseCacheTableExec extends LeafV2CommandExec {
def relationName: String
def planToCache: LogicalPlan
def dataFrameForCachedPlan: DataFrame
@ -117,7 +117,7 @@ case class CacheTableAsSelectExec(
case class UncacheTableExec(
relation: LogicalPlan,
cascade: Boolean) extends V2CommandExec {
cascade: Boolean) extends LeafV2CommandExec {
override def run(): Seq[InternalRow] = {
val sparkSession = sqlContext.sparkSession
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)


@ -34,7 +34,7 @@ case class CreateNamespaceExec(
namespace: Seq[String],
ifNotExists: Boolean,
private var properties: Map[String, String])
extends V2CommandExec {
extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._


@ -33,7 +33,7 @@ case class CreateTableExec(
tableSchema: StructType,
partitioning: Seq[Transform],
tableProperties: Map[String, String],
ignoreIfExists: Boolean) extends V2CommandExec {
ignoreIfExists: Boolean) extends LeafV2CommandExec {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
override protected def run(): Seq[InternalRow] = {


@ -25,7 +25,7 @@ import org.apache.spark.sql.sources.Filter
case class DeleteFromTableExec(
table: SupportsDelete,
condition: Array[Filter],
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
table.deleteWhere(condition)


@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
case class DescribeColumnExec(
override val output: Seq[Attribute],
column: Attribute,
isExtended: Boolean) extends V2CommandExec {
isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()


@ -31,7 +31,7 @@ case class DescribeNamespaceExec(
output: Seq[Attribute],
catalog: SupportsNamespaces,
namespace: Seq[String],
isExtended: Boolean) extends V2CommandExec {
isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()
val ns = namespace.toArray


@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataCo
case class DescribeTableExec(
output: Seq[Attribute],
table: Table,
isExtended: Boolean) extends V2CommandExec {
isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]()
addSchema(rows)


@ -30,7 +30,7 @@ case class DropNamespaceExec(
namespace: Seq[String],
ifExists: Boolean,
cascade: Boolean)
extends V2CommandExec {
extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._


@ -30,7 +30,7 @@ case class DropPartitionExec(
partSpecs: Seq[ResolvedPartitionSpec],
ignoreIfNotExists: Boolean,
purge: Boolean,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
import DataSourceV2Implicits._
override def output: Seq[Attribute] = Seq.empty


@ -30,7 +30,7 @@ case class DropTableExec(
ident: Identifier,
ifExists: Boolean,
purge: Boolean,
invalidateCache: () => Unit) extends V2CommandExec {
invalidateCache: () => Unit) extends LeafV2CommandExec {
override def run(): Seq[InternalRow] = {
if (catalog.tableExists(ident)) {


@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
case class RefreshTableExec(
catalog: TableCatalog,
ident: Identifier,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
catalog.invalidateTable(ident)


@ -29,7 +29,7 @@ case class RenamePartitionExec(
table: SupportsPartitionManagement,
from: ResolvedPartitionSpec,
to: ResolvedPartitionSpec,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty


@ -33,7 +33,7 @@ case class RenameTableExec(
newIdent: Identifier,
invalidateCache: () => Option[StorageLevel],
cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
extends V2CommandExec {
extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty


@ -35,7 +35,7 @@ case class ReplaceTableExec(
partitioning: Seq[Transform],
tableProperties: Map[String, String],
orCreate: Boolean,
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
if (catalog.tableExists(ident)) {
@ -59,7 +59,7 @@ case class AtomicReplaceTableExec(
partitioning: Seq[Transform],
tableProperties: Map[String, String],
orCreate: Boolean,
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
if (catalog.tableExists(identifier)) {


@ -28,7 +28,7 @@ case class SetCatalogAndNamespaceExec(
catalogManager: CatalogManager,
catalogName: Option[String],
namespace: Option[Seq[String]])
extends V2CommandExec {
extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
// The catalog is updated first because CatalogManager resets the current namespace
// when the current catalog is set.


@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
case class ShowCurrentNamespaceExec(
output: Seq[Attribute],
catalogManager: CatalogManager)
extends V2CommandExec {
extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
}


@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table}
case class ShowTablePropertiesExec(
output: Seq[Attribute],
catalogTable: Table,
propertyKey: Option[String]) extends V2CommandExec {
propertyKey: Option[String]) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = {
import scala.collection.JavaConverters._


@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement
case class TruncatePartitionExec(
table: SupportsPartitionManagement,
partSpec: ResolvedPartitionSpec,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty


@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.TruncatableTable
*/
case class TruncateTableExec(
table: TruncatableTable,
refreshCache: () => Unit) extends V2CommandExec {
refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty


@ -55,9 +55,8 @@ case class OverwriteByExpressionExecV1(
write: V1Write) extends V1FallbackWriters
/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
override def output: Seq[Attribute] = Nil
override final def children: Seq[SparkPlan] = Nil
def table: SupportsWrite
def refreshCache: () => Unit


@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.StructType
@ -57,8 +58,6 @@ abstract class V2CommandExec extends SparkPlan {
sqlContext.sparkContext.parallelize(result, 1)
}
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
protected def toCatalystRow(values: Any*): InternalRow = {
@ -69,3 +68,5 @@ abstract class V2CommandExec extends SparkPlan {
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
}
}
trait LeafV2CommandExec extends V2CommandExec with LeafLike[SparkPlan]
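
`LeafV2CommandExec` is what lets the many command nodes above drop their individual `children` overrides. A hypothetical new command would now look roughly like this (the command itself is made up; only the trait comes from this commit, and the sketch assumes it stays public in `org.apache.spark.sql.execution.datasources.v2`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.datasources.v2.LeafV2CommandExec

// A no-op physical command: LeafLike fixes `children` to Nil and V2CommandExec
// handles execution, so only `output` and `run` remain.
case class NoopCommandExec() extends LeafV2CommandExec {
  override def output: Seq[Attribute] = Seq.empty
  override protected def run(): Seq[InternalRow] = Seq.empty
}
```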


@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@ -3632,8 +3632,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
}
object DataFrameFunctionsSuite {
case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
override def children: Seq[Expression] = Seq(child)
case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
override lazy val resolved = true


@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType}
@ -358,8 +359,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
}
}
case class EmptyGenerator() extends Generator {
override def children: Seq[Expression] = Nil
case class EmptyGenerator() extends Generator with LeafLike[Expression] {
override def elementSchema: StructType = new StructType().add("id", IntegerType)
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {


@ -23,6 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
@ -232,8 +233,9 @@ object TypedImperativeAggregateSuite {
nullable: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
extends TypedImperativeAggregate[MaxValue]
with ImplicitCastInputTypes
with UnaryLike[Expression] {
override def createAggregationBuffer(): MaxValue = {
// Returns Int.MinValue if all inputs are null
@ -270,8 +272,6 @@ object TypedImperativeAggregateSuite {
override lazy val deterministic: Boolean = true
override def children: Seq[Expression] = Seq(child)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
override def dataType: DataType = IntegerType


@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.hive.execution.TestingTypedCount.State
import org.apache.spark.sql.types._
@ -32,12 +33,11 @@ case class TestingTypedCount(
child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[TestingTypedCount.State] {
extends TypedImperativeAggregate[TestingTypedCount.State]
with UnaryLike[Expression] {
def this(child: Expression) = this(child, 0, 0)
override def children: Seq[Expression] = child :: Nil
override def dataType: DataType = LongType
override def nullable: Boolean = false