[SPARK-34969][SPARK-34906][SQL] Followup for Refactor TreeNode's children handling methods into specialized traits

### What changes were proposed in this pull request?

This is a followup for https://github.com/apache/spark/pull/31932.
In this PR we:
- Introduce the `QuaternaryLike` trait for node types with 4 children.
- Specialize more node types
- Fix a number of style errors that were introduced in the original PR.

### Why are the changes needed?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

This is a refactoring, passes existing tests.

Closes #32065 from dbaliafroozeh/FollowupSPARK-34906.

Authored-by: Ali Afroozeh <ali.afroozeh@databricks.com>
Signed-off-by: herman <herman@databricks.com>
This commit is contained in:
Ali Afroozeh 2021-04-07 09:50:30 +02:00 committed by herman
parent 0aa2c284e4
commit 06c09a79b3
49 changed files with 127 additions and 87 deletions

View file

@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.functions.lit import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
weightExpr: Expression, weightExpr: Expression,
mutableAggBufferOffset: Int, mutableAggBufferOffset: Int,
inputAggBufferOffset: Int) inputAggBufferOffset: Int)
extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes { extends TypedImperativeAggregate[SummarizerBuffer]
with ImplicitCastInputTypes
with BinaryLike[Expression] {
override def eval(state: SummarizerBuffer): Any = { override def eval(state: SummarizerBuffer): Any = {
val metrics = requestedMetrics.map { val metrics = requestedMetrics.map {
@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {
override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil override def left: Expression = featuresExpr
override def right: Expression = weightExpr
override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = { override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
val features = vectorUDT.deserialize(featuresExpr.eval(row)) val features = vectorUDT.deserialize(featuresExpr.eval(row))

View file

@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf
@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
* An expression with four inputs and one output. The output is by default evaluated to null * An expression with four inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null. * if any input is evaluated to null.
*/ */
abstract class QuaternaryExpression extends Expression { abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {
override def foldable: Boolean = children.forall(_.foldable) override def foldable: Boolean = children.forall(_.foldable)
@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
* If subclass of QuaternaryExpression override nullable, probably should also override this. * If subclass of QuaternaryExpression override nullable, probably should also override this.
*/ */
override def eval(input: InternalRow): Any = { override def eval(input: InternalRow): Any = {
val exprs = children val value1 = first.eval(input)
val value1 = exprs(0).eval(input)
if (value1 != null) { if (value1 != null) {
val value2 = exprs(1).eval(input) val value2 = second.eval(input)
if (value2 != null) { if (value2 != null) {
val value3 = exprs(2).eval(input) val value3 = third.eval(input)
if (value3 != null) { if (value3 != null) {
val value4 = exprs(3).eval(input) val value4 = fourth.eval(input)
if (value4 != null) { if (value4 != null) {
return nullSafeEval(value1, value2, value3, value4) return nullSafeEval(value1, value2, value3, value4)
} }

View file

@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* }}} * }}}
*/ */
abstract class PartitionTransformExpression extends Expression with Unevaluable abstract class PartitionTransformExpression extends Expression with Unevaluable
with UnaryLike[Expression] { with UnaryLike[Expression] {
override def nullable: Boolean = true override def nullable: Boolean = true
} }

View file

@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs", group = "agg_funcs",
since = "1.0.0") since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] { with UnaryLike[Expression] {
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg") override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")

View file

@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
group = "agg_funcs", group = "agg_funcs",
since = "3.0.0") since = "3.0.0")
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] { with UnaryLike[Expression] {
override def prettyName: String = "count_if" override def prettyName: String = "count_if"

View file

@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.CountMinSketch import org.apache.spark.util.sketch.CountMinSketch
@ -60,7 +61,9 @@ case class CountMinSketchAgg(
seedExpression: Expression, seedExpression: Expression,
override val mutableAggBufferOffset: Int, override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int) override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes { extends TypedImperativeAggregate[CountMinSketch]
with ExpectsInputTypes
with QuaternaryLike[Expression] {
def this( def this(
child: Expression, child: Expression,
@ -145,8 +148,10 @@ case class CountMinSketchAgg(
override def defaultResult: Option[Literal] = override def defaultResult: Option[Literal] =
Option(Literal.create(eval(createAggregationBuffer()), dataType)) Option(Literal.create(eval(createAggregationBuffer()), dataType))
override def children: Seq[Expression] =
Seq(child, epsExpression, confidenceExpression, seedExpression)
override def prettyName: String = "count_min_sketch" override def prettyName: String = "count_min_sketch"
override def first: Expression = child
override def second: Expression = epsExpression
override def third: Expression = confidenceExpression
override def fourth: Expression = seedExpression
} }

View file

@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
* Compute the covariance between two expressions. * Compute the covariance between two expressions.
* When applied on empty data (i.e., count is zero), it returns NULL. * When applied on empty data (i.e., count is zero), it returns NULL.
*/ */
abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean) abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] { extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
override def left: Expression = x
override def right: Expression = y
override def nullable: Boolean = true override def nullable: Boolean = true
override def dataType: DataType = DoubleType override def dataType: DataType = DoubleType
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
protected def updateExpressionsDef: Seq[Expression] = { protected def updateExpressionsDef: Seq[Expression] = {
val newN = n + 1.0 val newN = n + 1.0
val dx = x - xAvg val dx = left - xAvg
val dy = y - yAvg val dy = right - yAvg
val dyN = dy / newN val dyN = dy / newN
val newXAvg = xAvg + dx / newN val newXAvg = xAvg + dx / newN
val newYAvg = yAvg + dyN val newYAvg = yAvg + dyN
val newCk = ck + dx * (y - newYAvg) val newCk = ck + dx * (right - newYAvg)
val isNull = x.isNull || y.isNull val isNull = left.isNull || right.isNull
Seq( Seq(
If(isNull, n, newN), If(isNull, n, newN),
If(isNull, xAvg, newXAvg), If(isNull, xAvg, newXAvg),

View file

@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
group = "agg_funcs", group = "agg_funcs",
since = "1.0.0") since = "1.0.0")
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
with UnaryLike[Expression] { with UnaryLike[Expression] {
override def nullable: Boolean = true override def nullable: Boolean = true

View file

@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType} import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
with UnaryLike[Expression] { with UnaryLike[Expression] {
val child: Expression val child: Expression

View file

@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
* can cause GC paused and eventually OutOfMemory Errors. * can cause GC paused and eventually OutOfMemory Errors.
*/ */
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
with UnaryLike[Expression] { with UnaryLike[Expression] {
val child: Expression val child: Expression

View file

@ -175,7 +175,7 @@ object GroupingSets {
group = "agg_funcs") group = "agg_funcs")
// scalastyle:on line.size.limit line.contains.tab // scalastyle:on line.size.limit line.contains.tab
case class Grouping(child: Expression) extends Expression with Unevaluable case class Grouping(child: Expression) extends Expression with Unevaluable
with UnaryLike[Expression] { with UnaryLike[Expression] {
@transient @transient
override lazy val references: AttributeSet = override lazy val references: AttributeSet =
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)

View file

@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf
@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
override def nullable: Boolean = arguments.exists(_.nullable) override def nullable: Boolean = arguments.exists(_.nullable)
override def children: Seq[Expression] = arguments ++ functions
/** /**
* Arguments of the higher ordered function. * Arguments of the higher ordered function.
*/ */
@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
/** /**
* Trait for functions having as input one argument and one function. * Trait for functions having as input one argument and one function.
*/ */
trait SimpleHigherOrderFunction extends HigherOrderFunction { trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {
def argument: Expression def argument: Expression
@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {
def functionForEval: Expression = functionsForEval.head def functionForEval: Expression = functionsForEval.head
override def left: Expression = argument
override def right: Expression = function
/** /**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code. * in order to save null-check code.
@ -694,7 +696,7 @@ case class ArrayAggregate(
zero: Expression, zero: Expression,
merge: Expression, merge: Expression,
finish: Expression) finish: Expression)
extends HigherOrderFunction with CodegenFallback { extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
def this(argument: Expression, zero: Expression, merge: Expression) = { def this(argument: Expression, zero: Expression, merge: Expression) = {
this(argument, zero, merge, LambdaFunction.identity) this(argument, zero, merge, LambdaFunction.identity)
@ -760,6 +762,11 @@ case class ArrayAggregate(
} }
override def prettyName: String = "aggregate" override def prettyName: String = "aggregate"
override def first: Expression = argument
override def second: Expression = zero
override def third: Expression = merge
override def fourth: Expression = finish
} }
/** /**
@ -884,7 +891,7 @@ case class TransformValues(
since = "3.0.0", since = "3.0.0",
group = "lambda_funcs") group = "lambda_funcs")
case class MapZipWith(left: Expression, right: Expression, function: Expression) case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback { extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
def functionForEval: Expression = functionsForEval.head def functionForEval: Expression = functionsForEval.head
@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
} }
override def prettyName: String = "map_zip_with" override def prettyName: String = "map_zip_with"
override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
} }
// scalastyle:off line.size.limit // scalastyle:off line.size.limit
@ -1063,7 +1074,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
group = "lambda_funcs") group = "lambda_funcs")
// scalastyle:on line.size.limit // scalastyle:on line.size.limit
case class ZipWith(left: Expression, right: Expression, function: Expression) case class ZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback { extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
def functionForEval: Expression = functionsForEval.head def functionForEval: Expression = functionsForEval.head
@ -1071,7 +1082,7 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
override def functions: Seq[Expression] = List(function) override def functions: Seq[Expression] = function :: Nil
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
} }
override def prettyName: String = "zip_with" override def prettyName: String = "zip_with"
override def first: Expression = left
override def second: Expression = right
override def third: Expression = function
} }

View file

@ -1488,7 +1488,6 @@ case class WidthBucket(
numBucket: Expression) numBucket: Expression)
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant { extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType) override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
override def dataType: DataType = LongType override def dataType: DataType = LongType
override def nullable: Boolean = true override def nullable: Boolean = true
@ -1507,4 +1506,9 @@ case class WidthBucket(
"org.apache.spark.sql.catalyst.expressions.WidthBucket" + "org.apache.spark.sql.catalyst.expressions.WidthBucket" +
s".computeBucketNumber($input, $min, $max, $numBucket)") s".computeBucketNumber($input, $min, $max, $numBucket)")
} }
override def first: Expression = value
override def second: Expression = minValue
override def third: Expression = maxValue
override def fourth: Expression = numBucket
} }

View file

@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
override def dataType: DataType = StringType override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, StringType, IntegerType) Seq(StringType, StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
override def prettyName: String = "regexp_replace" override def prettyName: String = "regexp_replace"
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
""" """
}) })
} }
override def first: Expression = subject
override def second: Expression = regexp
override def third: Expression = rep
override def fourth: Expression = pos
} }
object RegExpReplace { object RegExpReplace {

View file

@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType), override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
TypeCollection(StringType, BinaryType), IntegerType, IntegerType) TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
override def checkInputDataTypes(): TypeCheckResult = { override def checkInputDataTypes(): TypeCheckResult = {
val inputTypeCheck = super.checkInputDataTypes() val inputTypeCheck = super.checkInputDataTypes()
if (inputTypeCheck.isSuccess) { if (inputTypeCheck.isSuccess) {
@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
"org.apache.spark.sql.catalyst.expressions.Overlay" + "org.apache.spark.sql.catalyst.expressions.Overlay" +
s".calculate($input, $replace, $pos, $len);") s".calculate($input, $replace, $pos, $len);")
} }
override def first: Expression = input
override def second: Expression = replace
override def third: Expression = pos
override def fourth: Expression = len
} }
object StringTranslate { object StringTranslate {

View file

@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
group = "window_funcs") group = "window_funcs")
// scalastyle:on line.size.limit line.contains.tab // scalastyle:on line.size.limit line.contains.tab
case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
with UnaryLike[Expression] { with UnaryLike[Expression] {
def this() = this(Literal(1)) def this() = this(Literal(1))

View file

@ -431,7 +431,7 @@ case class InsertAction(
} }
case class Assignment(key: Expression, value: Expression) extends Expression case class Assignment(key: Expression, value: Expression) extends Expression
with Unevaluable with BinaryLike[Expression] { with Unevaluable with BinaryLike[Expression] {
override def nullable: Boolean = false override def nullable: Boolean = false
override def dataType: DataType = throw new UnresolvedException("nullable") override def dataType: DataType = throw new UnresolvedException("nullable")
override def left: Expression = key override def left: Expression = key

View file

@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
def third: T def third: T
@transient override final lazy val children: Seq[T] = first :: second :: third :: Nil @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
} }
trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
def first: T
def second: T
def third: T
def fourth: T
@transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
}

View file

@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
} }
} }
case class HugeCodeIntExpression(value: Int) extends Expression { case class HugeCodeIntExpression(value: Int) extends LeafExpression {
override def nullable: Boolean = true override def nullable: Boolean = true
override def dataType: DataType = IntegerType override def dataType: DataType = IntegerType
override def children: Seq[Expression] = Nil
override def eval(input: InternalRow): Any = value override def eval(input: InternalRow): Any = value
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Assuming HugeMethodLimit to be 8000 // Assuming HugeMethodLimit to be 8000

View file

@ -22,7 +22,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkContext import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
/** /**
* A special `Command` which writes data out and updates metrics. * A special `Command` which writes data out and updates metrics.
*/ */
trait DataWritingCommand extends Command { trait DataWritingCommand extends UnaryCommand {
/** /**
* The input query plan that produces the data to be written. * The input query plan that produces the data to be written.
* IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
@ -39,7 +39,7 @@ trait DataWritingCommand extends Command {
*/ */
def query: LogicalPlan def query: LogicalPlan
override final def children: Seq[LogicalPlan] = query :: Nil override final def child: LogicalPlan = query
// Output column names of the analyzed input query plan. // Output column names of the analyzed input query plan.
def outputColumnNames: Seq[String] def outputColumnNames: Seq[String]

View file

@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetric
@ -37,7 +37,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* A logical command that is executed for its side-effects. `RunnableCommand`s are * A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution. * wrapped in `ExecutedCommand` during execution.
*/ */
trait RunnableCommand extends LeafCommand { trait RunnableCommand extends Command {
override def children: Seq[LogicalPlan] = Nil
// The map used to record the metrics of running the command. This will be passed to // The map used to record the metrics of running the command. This will be passed to
// `ExecutedCommand` during query planning. // `ExecutedCommand` during query planning.

View file

@ -31,7 +31,7 @@ case class AddPartitionExec(
table: SupportsPartitionManagement, table: SupportsPartitionManagement,
partSpecs: Seq[ResolvedPartitionSpec], partSpecs: Seq[ResolvedPartitionSpec],
ignoreIfExists: Boolean, ignoreIfExists: Boolean,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
import DataSourceV2Implicits._ import DataSourceV2Implicits._
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac
case class AlterNamespaceSetPropertiesExec( case class AlterNamespaceSetPropertiesExec(
catalog: SupportsNamespaces, catalog: SupportsNamespaces,
namespace: Seq[String], namespace: Seq[String],
props: Map[String, String]) extends V2CommandExec { props: Map[String, String]) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
val changes = props.map{ case (k, v) => val changes = props.map{ case (k, v) =>
NamespaceChange.setProperty(k, v) NamespaceChange.setProperty(k, v)

View file

@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
case class AlterTableExec( case class AlterTableExec(
catalog: TableCatalog, catalog: TableCatalog,
ident: Identifier, ident: Identifier,
changes: Seq[TableChange]) extends V2CommandExec { changes: Seq[TableChange]) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdenti
import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
trait BaseCacheTableExec extends V2CommandExec { trait BaseCacheTableExec extends LeafV2CommandExec {
def relationName: String def relationName: String
def planToCache: LogicalPlan def planToCache: LogicalPlan
def dataFrameForCachedPlan: DataFrame def dataFrameForCachedPlan: DataFrame
@ -117,7 +117,7 @@ case class CacheTableAsSelectExec(
case class UncacheTableExec( case class UncacheTableExec(
relation: LogicalPlan, relation: LogicalPlan,
cascade: Boolean) extends V2CommandExec { cascade: Boolean) extends LeafV2CommandExec {
override def run(): Seq[InternalRow] = { override def run(): Seq[InternalRow] = {
val sparkSession = sqlContext.sparkSession val sparkSession = sqlContext.sparkSession
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade) sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)

View file

@ -34,7 +34,7 @@ case class CreateNamespaceExec(
namespace: Seq[String], namespace: Seq[String],
ifNotExists: Boolean, ifNotExists: Boolean,
private var properties: Map[String, String]) private var properties: Map[String, String])
extends V2CommandExec { extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._ import org.apache.spark.sql.connector.catalog.SupportsNamespaces._

View file

@ -33,7 +33,7 @@ case class CreateTableExec(
tableSchema: StructType, tableSchema: StructType,
partitioning: Seq[Transform], partitioning: Seq[Transform],
tableProperties: Map[String, String], tableProperties: Map[String, String],
ignoreIfExists: Boolean) extends V2CommandExec { ignoreIfExists: Boolean) extends LeafV2CommandExec {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {

View file

@ -25,7 +25,7 @@ import org.apache.spark.sql.sources.Filter
case class DeleteFromTableExec( case class DeleteFromTableExec(
table: SupportsDelete, table: SupportsDelete,
condition: Array[Filter], condition: Array[Filter],
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
table.deleteWhere(condition) table.deleteWhere(condition)

View file

@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
case class DescribeColumnExec( case class DescribeColumnExec(
override val output: Seq[Attribute], override val output: Seq[Attribute],
column: Attribute, column: Attribute,
isExtended: Boolean) extends V2CommandExec { isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]() val rows = new ArrayBuffer[InternalRow]()

View file

@ -31,7 +31,7 @@ case class DescribeNamespaceExec(
output: Seq[Attribute], output: Seq[Attribute],
catalog: SupportsNamespaces, catalog: SupportsNamespaces,
namespace: Seq[String], namespace: Seq[String],
isExtended: Boolean) extends V2CommandExec { isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]() val rows = new ArrayBuffer[InternalRow]()
val ns = namespace.toArray val ns = namespace.toArray

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataCo
case class DescribeTableExec( case class DescribeTableExec(
output: Seq[Attribute], output: Seq[Attribute],
table: Table, table: Table,
isExtended: Boolean) extends V2CommandExec { isExtended: Boolean) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
val rows = new ArrayBuffer[InternalRow]() val rows = new ArrayBuffer[InternalRow]()
addSchema(rows) addSchema(rows)

View file

@ -30,7 +30,7 @@ case class DropNamespaceExec(
namespace: Seq[String], namespace: Seq[String],
ifExists: Boolean, ifExists: Boolean,
cascade: Boolean) cascade: Boolean)
extends V2CommandExec { extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

View file

@ -30,7 +30,7 @@ case class DropPartitionExec(
partSpecs: Seq[ResolvedPartitionSpec], partSpecs: Seq[ResolvedPartitionSpec],
ignoreIfNotExists: Boolean, ignoreIfNotExists: Boolean,
purge: Boolean, purge: Boolean,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
import DataSourceV2Implicits._ import DataSourceV2Implicits._
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -30,7 +30,7 @@ case class DropTableExec(
ident: Identifier, ident: Identifier,
ifExists: Boolean, ifExists: Boolean,
purge: Boolean, purge: Boolean,
invalidateCache: () => Unit) extends V2CommandExec { invalidateCache: () => Unit) extends LeafV2CommandExec {
override def run(): Seq[InternalRow] = { override def run(): Seq[InternalRow] = {
if (catalog.tableExists(ident)) { if (catalog.tableExists(ident)) {

View file

@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
case class RefreshTableExec( case class RefreshTableExec(
catalog: TableCatalog, catalog: TableCatalog,
ident: Identifier, ident: Identifier,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
catalog.invalidateTable(ident) catalog.invalidateTable(ident)

View file

@ -29,7 +29,7 @@ case class RenamePartitionExec(
table: SupportsPartitionManagement, table: SupportsPartitionManagement,
from: ResolvedPartitionSpec, from: ResolvedPartitionSpec,
to: ResolvedPartitionSpec, to: ResolvedPartitionSpec,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -33,7 +33,7 @@ case class RenameTableExec(
newIdent: Identifier, newIdent: Identifier,
invalidateCache: () => Option[StorageLevel], invalidateCache: () => Option[StorageLevel],
cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit) cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
extends V2CommandExec { extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -35,7 +35,7 @@ case class ReplaceTableExec(
partitioning: Seq[Transform], partitioning: Seq[Transform],
tableProperties: Map[String, String], tableProperties: Map[String, String],
orCreate: Boolean, orCreate: Boolean,
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec { invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
if (catalog.tableExists(ident)) { if (catalog.tableExists(ident)) {
@ -59,7 +59,7 @@ case class AtomicReplaceTableExec(
partitioning: Seq[Transform], partitioning: Seq[Transform],
tableProperties: Map[String, String], tableProperties: Map[String, String],
orCreate: Boolean, orCreate: Boolean,
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec { invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
if (catalog.tableExists(identifier)) { if (catalog.tableExists(identifier)) {

View file

@ -28,7 +28,7 @@ case class SetCatalogAndNamespaceExec(
catalogManager: CatalogManager, catalogManager: CatalogManager,
catalogName: Option[String], catalogName: Option[String],
namespace: Option[Seq[String]]) namespace: Option[Seq[String]])
extends V2CommandExec { extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
// The catalog is updated first because CatalogManager resets the current namespace // The catalog is updated first because CatalogManager resets the current namespace
// when the current catalog is set. // when the current catalog is set.

View file

@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
case class ShowCurrentNamespaceExec( case class ShowCurrentNamespaceExec(
output: Seq[Attribute], output: Seq[Attribute],
catalogManager: CatalogManager) catalogManager: CatalogManager)
extends V2CommandExec { extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted)) Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
} }

View file

@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table}
case class ShowTablePropertiesExec( case class ShowTablePropertiesExec(
output: Seq[Attribute], output: Seq[Attribute],
catalogTable: Table, catalogTable: Table,
propertyKey: Option[String]) extends V2CommandExec { propertyKey: Option[String]) extends LeafV2CommandExec {
override protected def run(): Seq[InternalRow] = { override protected def run(): Seq[InternalRow] = {
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._

View file

@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement
case class TruncatePartitionExec( case class TruncatePartitionExec(
table: SupportsPartitionManagement, table: SupportsPartitionManagement,
partSpec: ResolvedPartitionSpec, partSpec: ResolvedPartitionSpec,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.TruncatableTable
*/ */
case class TruncateTableExec( case class TruncateTableExec(
table: TruncatableTable, table: TruncatableTable,
refreshCache: () => Unit) extends V2CommandExec { refreshCache: () => Unit) extends LeafV2CommandExec {
override def output: Seq[Attribute] = Seq.empty override def output: Seq[Attribute] = Seq.empty

View file

@ -55,9 +55,8 @@ case class OverwriteByExpressionExecV1(
write: V1Write) extends V1FallbackWriters write: V1Write) extends V1FallbackWriters
/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */ /** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
override def output: Seq[Attribute] = Nil override def output: Seq[Attribute] = Nil
override final def children: Seq[SparkPlan] = Nil
def table: SupportsWrite def table: SupportsWrite
def refreshCache: () => Unit def refreshCache: () => Unit

View file

@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema} import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.StructType import org.apache.spark.sql.types.StructType
@ -57,8 +58,6 @@ abstract class V2CommandExec extends SparkPlan {
sqlContext.sparkContext.parallelize(result, 1) sqlContext.sparkContext.parallelize(result, 1)
} }
override def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet override def producedAttributes: AttributeSet = outputSet
protected def toCatalystRow(values: Any*): InternalRow = { protected def toCatalystRow(values: Any*): InternalRow = {
@ -69,3 +68,5 @@ abstract class V2CommandExec extends SparkPlan {
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer() RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
} }
} }
trait LeafV2CommandExec extends V2CommandExec with LeafLike[SparkPlan]

View file

@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.SparkException import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@ -3632,8 +3632,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
} }
object DataFrameFunctionsSuite { object DataFrameFunctionsSuite {
case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback { case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
override def children: Seq[Expression] = Seq(child)
override def nullable: Boolean = child.nullable override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType override def dataType: DataType = child.dataType
override lazy val resolved = true override lazy val resolved = true

View file

@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.types.{IntegerType, StructType}
@ -358,8 +359,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
} }
} }
case class EmptyGenerator() extends Generator { case class EmptyGenerator() extends Generator with LeafLike[Expression] {
override def children: Seq[Expression] = Nil
override def elementSchema: StructType = new StructType().add("id", IntegerType) override def elementSchema: StructType = new StructType().add("id", IntegerType)
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

View file

@ -23,6 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
@ -232,8 +233,9 @@ object TypedImperativeAggregateSuite {
nullable: Boolean = false, nullable: Boolean = false,
mutableAggBufferOffset: Int = 0, mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes { extends TypedImperativeAggregate[MaxValue]
with ImplicitCastInputTypes
with UnaryLike[Expression] {
override def createAggregationBuffer(): MaxValue = { override def createAggregationBuffer(): MaxValue = {
// Returns Int.MinValue if all inputs are null // Returns Int.MinValue if all inputs are null
@ -270,8 +272,6 @@ object TypedImperativeAggregateSuite {
override lazy val deterministic: Boolean = true override lazy val deterministic: Boolean = true
override def children: Seq[Expression] = Seq(child)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
override def dataType: DataType = IntegerType override def dataType: DataType = IntegerType

View file

@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.hive.execution.TestingTypedCount.State import org.apache.spark.sql.hive.execution.TestingTypedCount.State
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
@ -32,12 +33,11 @@ case class TestingTypedCount(
child: Expression, child: Expression,
mutableAggBufferOffset: Int = 0, mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[TestingTypedCount.State] { extends TypedImperativeAggregate[TestingTypedCount.State]
with UnaryLike[Expression] {
def this(child: Expression) = this(child, 0, 0) def this(child: Expression) = this(child, 0, 0)
override def children: Seq[Expression] = child :: Nil
override def dataType: DataType = LongType override def dataType: DataType = LongType
override def nullable: Boolean = false override def nullable: Boolean = false