[SPARK-34969][SPARK-34906][SQL] Followup for Refactor TreeNode's children handling methods into specialized traits
### What changes were proposed in this pull request?
This is a followup for https://github.com/apache/spark/pull/31932. In this PR we:
- Introduce the `QuaternaryLike` trait for node types with 4 children.
- Specialize more node types.
- Fix a number of style errors that were introduced in the original PR.

### Why are the changes needed?
Code cleanup: this completes the refactoring started in #31932.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This is a refactoring, passes existing tests.

Closes #32065 from dbaliafroozeh/FollowupSPARK-34906.

Authored-by: Ali Afroozeh <ali.afroozeh@databricks.com>
Signed-off-by: herman <herman@databricks.com>
parent 0aa2c284e4
commit 06c09a79b3
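For orientation before the diff: the pattern this followup extends (defined in the `TreeNode` hunk further down) fixes a node's arity with a small trait that names each child slot and derives `children` from the slots. A minimal, self-contained sketch of that shape, with a simplified `Node` standing in for Spark's `TreeNode[T]`:

```scala
// Simplified sketch of the *-Like pattern; `Node` stands in for Spark's TreeNode[T].
trait Node { def children: Seq[Node] }

trait QuaternaryLike extends Node {
  def first: Node
  def second: Node
  def third: Node
  def fourth: Node
  // `children` is derived from the named slots once; a node mixing this in can
  // no longer drift out of sync by hand-writing its own children list.
  override final lazy val children: Seq[Node] = first :: second :: third :: fourth :: Nil
}
```

The concrete nodes touched below (`CountMinSketchAgg`, `ArrayAggregate`, `WidthBucket`, `RegExpReplace`, `Overlay`) drop their hand-written `children` overrides and instead point `first`/`second`/`third`/`fourth` at their constructor arguments.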
@@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
 
@@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
     weightExpr: Expression,
     mutableAggBufferOffset: Int,
     inputAggBufferOffset: Int)
-  extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
+  extends TypedImperativeAggregate[SummarizerBuffer]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
 
   override def eval(state: SummarizerBuffer): Any = {
     val metrics = requestedMetrics.map {
@@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {
 
   override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
 
-  override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
+  override def left: Expression = featuresExpr
+  override def right: Expression = weightExpr
 
   override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
     val features = vectorUDT.deserialize(featuresExpr.eval(row))
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
  * An expression with four inputs and one output. The output is by default evaluated to null
  * if any input is evaluated to null.
  */
-abstract class QuaternaryExpression extends Expression {
+abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
@@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
    * If subclass of QuaternaryExpression override nullable, probably should also override this.
    */
   override def eval(input: InternalRow): Any = {
-    val exprs = children
-    val value1 = exprs(0).eval(input)
+    val value1 = first.eval(input)
     if (value1 != null) {
-      val value2 = exprs(1).eval(input)
+      val value2 = second.eval(input)
       if (value2 != null) {
-        val value3 = exprs(2).eval(input)
+        val value3 = third.eval(input)
         if (value3 != null) {
-          val value4 = exprs(3).eval(input)
+          val value4 = fourth.eval(input)
           if (value4 != null) {
             return nullSafeEval(value1, value2, value3, value4)
           }
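The `eval` rewrite above is where the specialization shows up at runtime: the old body called `children` (a `def` that builds a fresh `Seq` on every invocation) and indexed into it, while the new body reads the four fixed slots directly and still short-circuits on the first null. A runnable sketch of the same evaluation shape, using toy `Expr`/`Lit` types rather than Spark's:

```scala
object QuaternaryEvalSketch extends App {
  trait Expr { def eval(): Any }
  case class Lit(value: Any) extends Expr { override def eval(): Any = value }

  abstract class QuaternaryExpr extends Expr {
    def first: Expr
    def second: Expr
    def third: Expr
    def fourth: Expr
    // Derived once from the named slots; eval below never builds or indexes it.
    final lazy val children: Seq[Expr] = first :: second :: third :: fourth :: Nil
    protected def nullSafeEval(a: Any, b: Any, c: Any, d: Any): Any

    override def eval(): Any = {
      val a = first.eval()
      if (a == null) return null
      val b = second.eval()
      if (b == null) return null
      val c = third.eval()
      if (c == null) return null
      val d = fourth.eval()
      if (d == null) return null
      nullSafeEval(a, b, c, d)
    }
  }

  // Hypothetical 4-ary expression, defined only for this demo.
  case class Add4(first: Expr, second: Expr, third: Expr, fourth: Expr) extends QuaternaryExpr {
    override protected def nullSafeEval(a: Any, b: Any, c: Any, d: Any): Any =
      a.asInstanceOf[Int] + b.asInstanceOf[Int] + c.asInstanceOf[Int] + d.asInstanceOf[Int]
  }

  println(Add4(Lit(1), Lit(2), Lit(3), Lit(4)).eval())    // 10
  println(Add4(Lit(1), Lit(null), Lit(3), Lit(4)).eval()) // null: stops at the first null
}
```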
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
  * }}}
  */
 abstract class PartitionTransformExpression extends Expression with Unevaluable
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   override def nullable: Boolean = true
 }
 

@@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
   group = "agg_funcs",
   since = "1.0.0")
 case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
 

@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
   group = "agg_funcs",
   since = "3.0.0")
 case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   override def prettyName: String = "count_if"
 
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
+import org.apache.spark.sql.catalyst.trees.QuaternaryLike
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
@@ -60,7 +61,9 @@ case class CountMinSketchAgg(
     seedExpression: Expression,
     override val mutableAggBufferOffset: Int,
     override val inputAggBufferOffset: Int)
-  extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
+  extends TypedImperativeAggregate[CountMinSketch]
+  with ExpectsInputTypes
+  with QuaternaryLike[Expression] {
 
   def this(
       child: Expression,
@@ -145,8 +148,10 @@ case class CountMinSketchAgg(
   override def defaultResult: Option[Literal] =
     Option(Literal.create(eval(createAggregationBuffer()), dataType))
 
-  override def children: Seq[Expression] =
-    Seq(child, epsExpression, confidenceExpression, seedExpression)
-
   override def prettyName: String = "count_min_sketch"
+
+  override def first: Expression = child
+  override def second: Expression = epsExpression
+  override def third: Expression = confidenceExpression
+  override def fourth: Expression = seedExpression
 }
@@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
  * Compute the covariance between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
  */
-abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
+abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
   extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
 
-  override def left: Expression = x
-  override def right: Expression = y
   override def nullable: Boolean = true
   override def dataType: DataType = DoubleType
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
@@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
 
   protected def updateExpressionsDef: Seq[Expression] = {
     val newN = n + 1.0
-    val dx = x - xAvg
-    val dy = y - yAvg
+    val dx = left - xAvg
+    val dy = right - yAvg
     val dyN = dy / newN
     val newXAvg = xAvg + dx / newN
     val newYAvg = yAvg + dyN
-    val newCk = ck + dx * (y - newYAvg)
+    val newCk = ck + dx * (right - newYAvg)
 
-    val isNull = x.isNull || y.isNull
+    val isNull = left.isNull || right.isNull
     Seq(
       If(isNull, n, newN),
       If(isNull, xAvg, newXAvg),
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
   group = "agg_funcs",
   since = "1.0.0")
 case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   override def nullable: Boolean = true
 

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
 
 abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   val child: Expression
 

@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
  * can cause GC paused and eventually OutOfMemory Errors.
  */
 abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   val child: Expression
 

@@ -175,7 +175,7 @@ object GroupingSets {
   group = "agg_funcs")
 // scalastyle:on line.size.limit line.contains.tab
 case class Grouping(child: Expression) extends Expression with Unevaluable
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
   @transient
   override lazy val references: AttributeSet =
     AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
@@ -25,6 +25,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
 
   override def nullable: Boolean = arguments.exists(_.nullable)
 
-  override def children: Seq[Expression] = arguments ++ functions
-
   /**
    * Arguments of the higher ordered function.
    */
@@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
 /**
  * Trait for functions having as input one argument and one function.
  */
-trait SimpleHigherOrderFunction extends HigherOrderFunction {
+trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {
 
   def argument: Expression
 
@@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {
 
   def functionForEval: Expression = functionsForEval.head
 
+  override def left: Expression = argument
+  override def right: Expression = function
+
   /**
    * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
    * in order to save null-check code.
@@ -694,7 +696,7 @@ case class ArrayAggregate(
     zero: Expression,
     merge: Expression,
     finish: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
 
   def this(argument: Expression, zero: Expression, merge: Expression) = {
     this(argument, zero, merge, LambdaFunction.identity)
@@ -760,6 +762,11 @@ case class ArrayAggregate(
   }
 
   override def prettyName: String = "aggregate"
+
+  override def first: Expression = argument
+  override def second: Expression = zero
+  override def third: Expression = merge
+  override def fourth: Expression = finish
 }
 
 /**
@@ -884,7 +891,7 @@ case class TransformValues(
   since = "3.0.0",
   group = "lambda_funcs")
 case class MapZipWith(left: Expression, right: Expression, function: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
 
   def functionForEval: Expression = functionsForEval.head
 
@@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   }
 
   override def prettyName: String = "map_zip_with"
+
+  override def first: Expression = left
+  override def second: Expression = right
+  override def third: Expression = function
 }
 
 // scalastyle:off line.size.limit
@@ -1063,7 +1074,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
   group = "lambda_funcs")
 // scalastyle:on line.size.limit
 case class ZipWith(left: Expression, right: Expression, function: Expression)
-  extends HigherOrderFunction with CodegenFallback {
+  extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
 
   def functionForEval: Expression = functionsForEval.head
 
@@ -1071,7 +1082,7 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
 
   override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
 
-  override def functions: Seq[Expression] = List(function)
+  override def functions: Seq[Expression] = function :: Nil
 
   override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
 
@@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
   }
 
   override def prettyName: String = "zip_with"
+
+  override def first: Expression = left
+  override def second: Expression = right
+  override def third: Expression = function
 }
@@ -1488,7 +1488,6 @@ case class WidthBucket(
     numBucket: Expression)
   extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
-  override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
   override def dataType: DataType = LongType
   override def nullable: Boolean = true
@@ -1507,4 +1506,9 @@ case class WidthBucket(
       "org.apache.spark.sql.catalyst.expressions.WidthBucket" +
         s".computeBucketNumber($input, $min, $max, $numBucket)")
   }
+
+  override def first: Expression = value
+  override def second: Expression = minValue
+  override def third: Expression = maxValue
+  override def fourth: Expression = numBucket
 }

@@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
   override def dataType: DataType = StringType
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringType, StringType, StringType, IntegerType)
-  override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
   override def prettyName: String = "regexp_replace"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
     """
     })
   }
+
+  override def first: Expression = subject
+  override def second: Expression = regexp
+  override def third: Expression = rep
+  override def fourth: Expression = pos
 }
 
 object RegExpReplace {
@@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
     TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
 
-  override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
-
   override def checkInputDataTypes(): TypeCheckResult = {
     val inputTypeCheck = super.checkInputDataTypes()
     if (inputTypeCheck.isSuccess) {
@@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
       "org.apache.spark.sql.catalyst.expressions.Overlay" +
         s".calculate($input, $replace, $pos, $len);")
   }
+
+  override def first: Expression = input
+  override def second: Expression = replace
+  override def third: Expression = pos
+  override def fourth: Expression = len
 }
 
 object StringTranslate {

@@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
   group = "window_funcs")
 // scalastyle:on line.size.limit line.contains.tab
 case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
-    with UnaryLike[Expression] {
+  with UnaryLike[Expression] {
 
   def this() = this(Literal(1))
 

@@ -431,7 +431,7 @@ case class InsertAction(
 }
 
 case class Assignment(key: Expression, value: Expression) extends Expression
-    with Unevaluable with BinaryLike[Expression] {
+  with Unevaluable with BinaryLike[Expression] {
   override def nullable: Boolean = false
   override def dataType: DataType = throw new UnresolvedException("nullable")
   override def left: Expression = key
@@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
   def third: T
   @transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
 }
+
+trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
+  def first: T
+  def second: T
+  def third: T
+  def fourth: T
+  @transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
+}
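One detail worth noting in the trait just added: `children` is `final` (subclasses cannot diverge from the slots), `lazy` (built once), and `@transient` (the cached `Seq` is dropped on serialization and rebuilt from the slots on first access afterwards). A self-contained demo of that serialization behavior, again with a simplified `Node` in place of `TreeNode[T]`:

```scala
import java.io._

object TransientLazyChildrenDemo extends App {
  trait Node extends Serializable { def children: Seq[Node] }

  trait QuaternaryLike extends Node {
    def first: Node
    def second: Node
    def third: Node
    def fourth: Node
    // Cached once, excluded from serialization, rebuilt from the slots on demand.
    @transient override final lazy val children: Seq[Node] =
      first :: second :: third :: fourth :: Nil
  }

  case class Leaf(name: String) extends Node { override val children: Seq[Node] = Nil }
  case class Quad(first: Node, second: Node, third: Node, fourth: Node) extends QuaternaryLike

  val q = Quad(Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d"))
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(q)
  out.close()
  val q2 = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[Quad]
  println(q2.children.collect { case Leaf(n) => n }) // List(a, b, c, d)
}
```

Without the `@transient`, every serialized plan fragment would also carry a materialized copy of a child list it can always recompute.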
@@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 }
 
-case class HugeCodeIntExpression(value: Int) extends Expression {
+case class HugeCodeIntExpression(value: Int) extends LeafExpression {
   override def nullable: Boolean = true
   override def dataType: DataType = IntegerType
-  override def children: Seq[Expression] = Nil
   override def eval(input: InternalRow): Any = value
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Assuming HugeMethodLimit to be 8000
@@ -22,7 +22,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
 /**
  * A special `Command` which writes data out and updates metrics.
  */
-trait DataWritingCommand extends Command {
+trait DataWritingCommand extends UnaryCommand {
   /**
    * The input query plan that produces the data to be written.
    * IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
@@ -39,7 +39,7 @@ trait DataWritingCommand extends Command {
   */
   def query: LogicalPlan
 
-  override final def children: Seq[LogicalPlan] = query :: Nil
+  override final def child: LogicalPlan = query
 
   // Output column names of the analyzed input query plan.
   def outputColumnNames: Seq[String]

@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
 import org.apache.spark.sql.connector.ExternalCommandRunner
 import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -37,7 +37,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * A logical command that is executed for its side-effects. `RunnableCommand`s are
  * wrapped in `ExecutedCommand` during execution.
  */
-trait RunnableCommand extends LeafCommand {
+trait RunnableCommand extends Command {
 
+  override def children: Seq[LogicalPlan] = Nil
+
   // The map used to record the metrics of running the command. This will be passed to
   // `ExecutedCommand` during query planning.
|
@ -31,7 +31,7 @@ case class AddPartitionExec(
|
|||
table: SupportsPartitionManagement,
|
||||
partSpecs: Seq[ResolvedPartitionSpec],
|
||||
ignoreIfExists: Boolean,
|
||||
refreshCache: () => Unit) extends V2CommandExec {
|
||||
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||
import DataSourceV2Implicits._
|
||||
|
||||
override def output: Seq[Attribute] = Seq.empty
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac
|
|||
case class AlterNamespaceSetPropertiesExec(
|
||||
catalog: SupportsNamespaces,
|
||||
namespace: Seq[String],
|
||||
props: Map[String, String]) extends V2CommandExec {
|
||||
props: Map[String, String]) extends LeafV2CommandExec {
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
val changes = props.map{ case (k, v) =>
|
||||
NamespaceChange.setProperty(k, v)
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
|
|||
case class AlterTableExec(
|
||||
catalog: TableCatalog,
|
||||
ident: Identifier,
|
||||
changes: Seq[TableChange]) extends V2CommandExec {
|
||||
changes: Seq[TableChange]) extends LeafV2CommandExec {
|
||||
|
||||
override def output: Seq[Attribute] = Seq.empty
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdenti
|
|||
import org.apache.spark.sql.execution.command.CreateViewCommand
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
trait BaseCacheTableExec extends V2CommandExec {
|
||||
trait BaseCacheTableExec extends LeafV2CommandExec {
|
||||
def relationName: String
|
||||
def planToCache: LogicalPlan
|
||||
def dataFrameForCachedPlan: DataFrame
|
||||
|
@ -117,7 +117,7 @@ case class CacheTableAsSelectExec(
|
|||
|
||||
case class UncacheTableExec(
|
||||
relation: LogicalPlan,
|
||||
cascade: Boolean) extends V2CommandExec {
|
||||
cascade: Boolean) extends LeafV2CommandExec {
|
||||
override def run(): Seq[InternalRow] = {
|
||||
val sparkSession = sqlContext.sparkSession
|
||||
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)
|
||||
|
|
|
@ -34,7 +34,7 @@ case class CreateNamespaceExec(
|
|||
namespace: Seq[String],
|
||||
ifNotExists: Boolean,
|
||||
private var properties: Map[String, String])
|
||||
extends V2CommandExec {
|
||||
extends LeafV2CommandExec {
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
|
||||
|
|
|
@ -33,7 +33,7 @@ case class CreateTableExec(
|
|||
tableSchema: StructType,
|
||||
partitioning: Seq[Transform],
|
||||
tableProperties: Map[String, String],
|
||||
ignoreIfExists: Boolean) extends V2CommandExec {
|
||||
ignoreIfExists: Boolean) extends LeafV2CommandExec {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.sources.Filter
|
|||
case class DeleteFromTableExec(
|
||||
table: SupportsDelete,
|
||||
condition: Array[Filter],
|
||||
refreshCache: () => Unit) extends V2CommandExec {
|
||||
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
table.deleteWhere(condition)
|
||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
|
|||
case class DescribeColumnExec(
|
||||
override val output: Seq[Attribute],
|
||||
column: Attribute,
|
||||
isExtended: Boolean) extends V2CommandExec {
|
||||
isExtended: Boolean) extends LeafV2CommandExec {
|
||||
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
val rows = new ArrayBuffer[InternalRow]()
|
||||
|
|
|
@ -31,7 +31,7 @@ case class DescribeNamespaceExec(
|
|||
output: Seq[Attribute],
|
||||
catalog: SupportsNamespaces,
|
||||
namespace: Seq[String],
|
||||
isExtended: Boolean) extends V2CommandExec {
|
||||
isExtended: Boolean) extends LeafV2CommandExec {
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
val rows = new ArrayBuffer[InternalRow]()
|
||||
val ns = namespace.toArray
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataCo
|
|||
case class DescribeTableExec(
|
||||
output: Seq[Attribute],
|
||||
table: Table,
|
||||
isExtended: Boolean) extends V2CommandExec {
|
||||
isExtended: Boolean) extends LeafV2CommandExec {
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
val rows = new ArrayBuffer[InternalRow]()
|
||||
addSchema(rows)
|
||||
|
|
|
@ -30,7 +30,7 @@ case class DropNamespaceExec(
|
|||
namespace: Seq[String],
|
||||
ifExists: Boolean,
|
||||
cascade: Boolean)
|
||||
extends V2CommandExec {
|
||||
extends LeafV2CommandExec {
|
||||
override protected def run(): Seq[InternalRow] = {
|
||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||
|
||||
|
|
|
@@ -30,7 +30,7 @@ case class DropPartitionExec(
     partSpecs: Seq[ResolvedPartitionSpec],
     ignoreIfNotExists: Boolean,
     purge: Boolean,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   import DataSourceV2Implicits._
 
   override def output: Seq[Attribute] = Seq.empty

@@ -30,7 +30,7 @@ case class DropTableExec(
     ident: Identifier,
     ifExists: Boolean,
     purge: Boolean,
-    invalidateCache: () => Unit) extends V2CommandExec {
+    invalidateCache: () => Unit) extends LeafV2CommandExec {
 
   override def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {

@@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 case class RefreshTableExec(
     catalog: TableCatalog,
     ident: Identifier,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     catalog.invalidateTable(ident)
 

@@ -29,7 +29,7 @@ case class RenamePartitionExec(
     table: SupportsPartitionManagement,
     from: ResolvedPartitionSpec,
     to: ResolvedPartitionSpec,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
 
   override def output: Seq[Attribute] = Seq.empty
 

@@ -33,7 +33,7 @@ case class RenameTableExec(
     newIdent: Identifier,
     invalidateCache: () => Option[StorageLevel],
     cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
 
   override def output: Seq[Attribute] = Seq.empty
 

@@ -35,7 +35,7 @@ case class ReplaceTableExec(
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
     orCreate: Boolean,
-    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(ident)) {
@@ -59,7 +59,7 @@ case class AtomicReplaceTableExec(
     partitioning: Seq[Transform],
     tableProperties: Map[String, String],
     orCreate: Boolean,
-    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
+    invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     if (catalog.tableExists(identifier)) {

@@ -28,7 +28,7 @@ case class SetCatalogAndNamespaceExec(
     catalogManager: CatalogManager,
     catalogName: Option[String],
     namespace: Option[Seq[String]])
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     // The catalog is updated first because CatalogManager resets the current namespace
     // when the current catalog is set.

@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
 case class ShowCurrentNamespaceExec(
     output: Seq[Attribute],
     catalogManager: CatalogManager)
-  extends V2CommandExec {
+  extends LeafV2CommandExec {
   override protected def run(): Seq[InternalRow] = {
     Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
   }

@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table}
 case class ShowTablePropertiesExec(
     output: Seq[Attribute],
     catalogTable: Table,
-    propertyKey: Option[String]) extends V2CommandExec {
+    propertyKey: Option[String]) extends LeafV2CommandExec {
 
   override protected def run(): Seq[InternalRow] = {
     import scala.collection.JavaConverters._

@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement
 case class TruncatePartitionExec(
     table: SupportsPartitionManagement,
     partSpec: ResolvedPartitionSpec,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
 
   override def output: Seq[Attribute] = Seq.empty
 

@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.TruncatableTable
  */
 case class TruncateTableExec(
     table: TruncatableTable,
-    refreshCache: () => Unit) extends V2CommandExec {
+    refreshCache: () => Unit) extends LeafV2CommandExec {
 
   override def output: Seq[Attribute] = Seq.empty
 
@@ -55,9 +55,8 @@ case class OverwriteByExpressionExecV1(
     write: V1Write) extends V1FallbackWriters
 
 /** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
-sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
+sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
   override def output: Seq[Attribute] = Nil
-  override final def children: Seq[SparkPlan] = Nil
 
   def table: SupportsWrite
   def refreshCache: () => Unit
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.StructType
 
@@ -57,8 +58,6 @@ abstract class V2CommandExec extends SparkPlan {
     sqlContext.sparkContext.parallelize(result, 1)
   }
 
-  override def children: Seq[SparkPlan] = Nil
-
   override def producedAttributes: AttributeSet = outputSet
 
   protected def toCatalystRow(values: Any*): InternalRow = {
@@ -69,3 +68,5 @@ abstract class V2CommandExec extends SparkPlan {
     RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
   }
 }
+
+trait LeafV2CommandExec extends V2CommandExec with LeafLike[SparkPlan]
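The long run of `V2CommandExec` → `LeafV2CommandExec` renames above all hangs off the one-line trait just added: leaf commands now inherit `children = Nil` from a shared marker trait instead of each writing out the override (the base class's `children` override is deleted in the same hunk). A minimal sketch of the arrangement, with stand-in types rather than the real `SparkPlan` hierarchy:

```scala
object LeafCommandSketch extends App {
  trait PlanNode { def children: Seq[PlanNode] }

  // One shared definition of "no children" instead of an override in every command.
  trait LeafLike extends PlanNode {
    override final def children: Seq[PlanNode] = Nil
  }

  abstract class CommandExec extends PlanNode {
    def run(): Seq[String]
  }

  trait LeafCommandExec extends CommandExec with LeafLike

  // Hypothetical command, for the demo only.
  case class TruncateDemoExec(tableName: String) extends LeafCommandExec {
    override def run(): Seq[String] = Seq(s"truncated $tableName")
  }

  println(TruncateDemoExec("t").children.isEmpty) // true
  println(TruncateDemoExec("t").run())            // List(truncated t)
}
```

Because `children` is final in the leaf trait, each command's change in this commit reduces to a one-word rename of its parent type.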
@@ -24,7 +24,7 @@ import scala.util.Random
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
@@ -3632,8 +3632,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 }
 
 object DataFrameFunctionsSuite {
-  case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
-    override def children: Seq[Expression] = Seq(child)
+  case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
     override def nullable: Boolean = child.nullable
     override def dataType: DataType = child.dataType
     override lazy val resolved = true

@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.trees.LeafLike
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -358,8 +359,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
   }
 }
 
-case class EmptyGenerator() extends Generator {
-  override def children: Seq[Expression] = Nil
+case class EmptyGenerator() extends Generator with LeafLike[Expression] {
   override def elementSchema: StructType = new StructType().add("id", IntegerType)
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

@@ -23,6 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
@@ -232,8 +233,9 @@ object TypedImperativeAggregateSuite {
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
       inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
-
+    extends TypedImperativeAggregate[MaxValue]
+    with ImplicitCastInputTypes
+    with UnaryLike[Expression] {
 
     override def createAggregationBuffer(): MaxValue = {
       // Returns Int.MinValue if all inputs are null
@@ -270,8 +272,6 @@ object TypedImperativeAggregateSuite {
 
     override lazy val deterministic: Boolean = true
 
-    override def children: Seq[Expression] = Seq(child)
-
     override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
 
     override def dataType: DataType = IntegerType

@@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.hive.execution.TestingTypedCount.State
 import org.apache.spark.sql.types._
 
@@ -32,12 +33,11 @@ case class TestingTypedCount(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[TestingTypedCount.State] {
+  extends TypedImperativeAggregate[TestingTypedCount.State]
+  with UnaryLike[Expression] {
 
   def this(child: Expression) = this(child, 0, 0)
 
-  override def children: Seq[Expression] = child :: Nil
-
   override def dataType: DataType = LongType
 
   override def nullable: Boolean = false