[SPARK-34969][SPARK-34906][SQL] Followup for Refactor TreeNode's children handling methods into specialized traits
### What changes were proposed in this pull request? This is a followup for https://github.com/apache/spark/pull/31932. In this PR we: - Introduce the `QuaternaryLike` trait for node types with 4 children. - Specialize more node types - Fix a number of style errors that were introduced in the original PR. ### Why are the changes needed? ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? This is a refactoring, passes existing tests. Closes #32065 from dbaliafroozeh/FollowupSPARK-34906. Authored-by: Ali Afroozeh <ali.afroozeh@databricks.com> Signed-off-by: herman <herman@databricks.com>
This commit is contained in:
parent
0aa2c284e4
commit
06c09a79b3
|
@ -28,6 +28,7 @@ import org.apache.spark.sql.Column
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
|
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
|
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, TypedImperativeAggregate}
|
||||||
|
import org.apache.spark.sql.catalyst.trees.BinaryLike
|
||||||
import org.apache.spark.sql.functions.lit
|
import org.apache.spark.sql.functions.lit
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -348,7 +349,9 @@ private[spark] object SummaryBuilderImpl extends Logging {
|
||||||
weightExpr: Expression,
|
weightExpr: Expression,
|
||||||
mutableAggBufferOffset: Int,
|
mutableAggBufferOffset: Int,
|
||||||
inputAggBufferOffset: Int)
|
inputAggBufferOffset: Int)
|
||||||
extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
|
extends TypedImperativeAggregate[SummarizerBuffer]
|
||||||
|
with ImplicitCastInputTypes
|
||||||
|
with BinaryLike[Expression] {
|
||||||
|
|
||||||
override def eval(state: SummarizerBuffer): Any = {
|
override def eval(state: SummarizerBuffer): Any = {
|
||||||
val metrics = requestedMetrics.map {
|
val metrics = requestedMetrics.map {
|
||||||
|
@ -368,7 +371,8 @@ private[spark] object SummaryBuilderImpl extends Logging {
|
||||||
|
|
||||||
override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
|
override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
|
||||||
|
|
||||||
override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
|
override def left: Expression = featuresExpr
|
||||||
|
override def right: Expression = weightExpr
|
||||||
|
|
||||||
override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
|
override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
|
||||||
val features = vectorUDT.deserialize(featuresExpr.eval(row))
|
val features = vectorUDT.deserialize(featuresExpr.eval(row))
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
|
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||||
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TernaryLike, TreeNode, UnaryLike}
|
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike}
|
||||||
import org.apache.spark.sql.catalyst.util.truncatedString
|
import org.apache.spark.sql.catalyst.util.truncatedString
|
||||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
|
@ -786,7 +786,7 @@ abstract class TernaryExpression extends Expression with TernaryLike[Expression]
|
||||||
* An expression with four inputs and one output. The output is by default evaluated to null
|
* An expression with four inputs and one output. The output is by default evaluated to null
|
||||||
* if any input is evaluated to null.
|
* if any input is evaluated to null.
|
||||||
*/
|
*/
|
||||||
abstract class QuaternaryExpression extends Expression {
|
abstract class QuaternaryExpression extends Expression with QuaternaryLike[Expression] {
|
||||||
|
|
||||||
override def foldable: Boolean = children.forall(_.foldable)
|
override def foldable: Boolean = children.forall(_.foldable)
|
||||||
|
|
||||||
|
@ -797,14 +797,13 @@ abstract class QuaternaryExpression extends Expression {
|
||||||
* If subclass of QuaternaryExpression override nullable, probably should also override this.
|
* If subclass of QuaternaryExpression override nullable, probably should also override this.
|
||||||
*/
|
*/
|
||||||
override def eval(input: InternalRow): Any = {
|
override def eval(input: InternalRow): Any = {
|
||||||
val exprs = children
|
val value1 = first.eval(input)
|
||||||
val value1 = exprs(0).eval(input)
|
|
||||||
if (value1 != null) {
|
if (value1 != null) {
|
||||||
val value2 = exprs(1).eval(input)
|
val value2 = second.eval(input)
|
||||||
if (value2 != null) {
|
if (value2 != null) {
|
||||||
val value3 = exprs(2).eval(input)
|
val value3 = third.eval(input)
|
||||||
if (value3 != null) {
|
if (value3 != null) {
|
||||||
val value4 = exprs(3).eval(input)
|
val value4 = fourth.eval(input)
|
||||||
if (value4 != null) {
|
if (value4 != null) {
|
||||||
return nullSafeEval(value1, value2, value3, value4)
|
return nullSafeEval(value1, value2, value3, value4)
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
|
||||||
* }}}
|
* }}}
|
||||||
*/
|
*/
|
||||||
abstract class PartitionTransformExpression extends Expression with Unevaluable
|
abstract class PartitionTransformExpression extends Expression with Unevaluable
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.apache.spark.sql.types._
|
||||||
group = "agg_funcs",
|
group = "agg_funcs",
|
||||||
since = "1.0.0")
|
since = "1.0.0")
|
||||||
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
|
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
|
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, Long
|
||||||
group = "agg_funcs",
|
group = "agg_funcs",
|
||||||
since = "3.0.0")
|
since = "3.0.0")
|
||||||
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
|
case class CountIf(predicate: Expression) extends UnevaluableAggregate with ImplicitCastInputTypes
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
override def prettyName: String = "count_if"
|
override def prettyName: String = "count_if"
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
|
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
|
||||||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
|
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
|
||||||
|
import org.apache.spark.sql.catalyst.trees.QuaternaryLike
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
import org.apache.spark.unsafe.types.UTF8String
|
||||||
import org.apache.spark.util.sketch.CountMinSketch
|
import org.apache.spark.util.sketch.CountMinSketch
|
||||||
|
@ -60,7 +61,9 @@ case class CountMinSketchAgg(
|
||||||
seedExpression: Expression,
|
seedExpression: Expression,
|
||||||
override val mutableAggBufferOffset: Int,
|
override val mutableAggBufferOffset: Int,
|
||||||
override val inputAggBufferOffset: Int)
|
override val inputAggBufferOffset: Int)
|
||||||
extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
|
extends TypedImperativeAggregate[CountMinSketch]
|
||||||
|
with ExpectsInputTypes
|
||||||
|
with QuaternaryLike[Expression] {
|
||||||
|
|
||||||
def this(
|
def this(
|
||||||
child: Expression,
|
child: Expression,
|
||||||
|
@ -145,8 +148,10 @@ case class CountMinSketchAgg(
|
||||||
override def defaultResult: Option[Literal] =
|
override def defaultResult: Option[Literal] =
|
||||||
Option(Literal.create(eval(createAggregationBuffer()), dataType))
|
Option(Literal.create(eval(createAggregationBuffer()), dataType))
|
||||||
|
|
||||||
override def children: Seq[Expression] =
|
|
||||||
Seq(child, epsExpression, confidenceExpression, seedExpression)
|
|
||||||
|
|
||||||
override def prettyName: String = "count_min_sketch"
|
override def prettyName: String = "count_min_sketch"
|
||||||
|
|
||||||
|
override def first: Expression = child
|
||||||
|
override def second: Expression = epsExpression
|
||||||
|
override def third: Expression = confidenceExpression
|
||||||
|
override def fourth: Expression = seedExpression
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,9 @@ import org.apache.spark.sql.types._
|
||||||
* Compute the covariance between two expressions.
|
* Compute the covariance between two expressions.
|
||||||
* When applied on empty data (i.e., count is zero), it returns NULL.
|
* When applied on empty data (i.e., count is zero), it returns NULL.
|
||||||
*/
|
*/
|
||||||
abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Boolean)
|
abstract class Covariance(val left: Expression, val right: Expression, nullOnDivideByZero: Boolean)
|
||||||
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
|
extends DeclarativeAggregate with ImplicitCastInputTypes with BinaryLike[Expression] {
|
||||||
|
|
||||||
override def left: Expression = x
|
|
||||||
override def right: Expression = y
|
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
override def dataType: DataType = DoubleType
|
override def dataType: DataType = DoubleType
|
||||||
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
|
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
|
||||||
|
@ -72,14 +70,14 @@ abstract class Covariance(x: Expression, y: Expression, nullOnDivideByZero: Bool
|
||||||
|
|
||||||
protected def updateExpressionsDef: Seq[Expression] = {
|
protected def updateExpressionsDef: Seq[Expression] = {
|
||||||
val newN = n + 1.0
|
val newN = n + 1.0
|
||||||
val dx = x - xAvg
|
val dx = left - xAvg
|
||||||
val dy = y - yAvg
|
val dy = right - yAvg
|
||||||
val dyN = dy / newN
|
val dyN = dy / newN
|
||||||
val newXAvg = xAvg + dx / newN
|
val newXAvg = xAvg + dx / newN
|
||||||
val newYAvg = yAvg + dyN
|
val newYAvg = yAvg + dyN
|
||||||
val newCk = ck + dx * (y - newYAvg)
|
val newCk = ck + dx * (right - newYAvg)
|
||||||
|
|
||||||
val isNull = x.isNull || y.isNull
|
val isNull = left.isNull || right.isNull
|
||||||
Seq(
|
Seq(
|
||||||
If(isNull, n, newN),
|
If(isNull, n, newN),
|
||||||
If(isNull, xAvg, newXAvg),
|
If(isNull, xAvg, newXAvg),
|
||||||
|
|
|
@ -39,7 +39,7 @@ import org.apache.spark.sql.types._
|
||||||
group = "agg_funcs",
|
group = "agg_funcs",
|
||||||
since = "1.0.0")
|
since = "1.0.0")
|
||||||
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
|
case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
|
||||||
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
|
import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegralType}
|
||||||
|
|
||||||
abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
|
abstract class BitAggregate extends DeclarativeAggregate with ExpectsInputTypes
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
val child: Expression
|
val child: Expression
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
|
||||||
* can cause GC paused and eventually OutOfMemory Errors.
|
* can cause GC paused and eventually OutOfMemory Errors.
|
||||||
*/
|
*/
|
||||||
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
|
abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T]
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
val child: Expression
|
val child: Expression
|
||||||
|
|
||||||
|
|
|
@ -175,7 +175,7 @@ object GroupingSets {
|
||||||
group = "agg_funcs")
|
group = "agg_funcs")
|
||||||
// scalastyle:on line.size.limit line.contains.tab
|
// scalastyle:on line.size.limit line.contains.tab
|
||||||
case class Grouping(child: Expression) extends Expression with Unevaluable
|
case class Grouping(child: Expression) extends Expression with Unevaluable
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
@transient
|
@transient
|
||||||
override lazy val references: AttributeSet =
|
override lazy val references: AttributeSet =
|
||||||
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
|
AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
|
||||||
|
|
|
@ -25,6 +25,7 @@ import scala.collection.mutable
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
|
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
|
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
|
||||||
import org.apache.spark.sql.catalyst.util._
|
import org.apache.spark.sql.catalyst.util._
|
||||||
import org.apache.spark.sql.errors.QueryExecutionErrors
|
import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
import org.apache.spark.sql.internal.SQLConf
|
import org.apache.spark.sql.internal.SQLConf
|
||||||
|
@ -119,8 +120,6 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
|
||||||
|
|
||||||
override def nullable: Boolean = arguments.exists(_.nullable)
|
override def nullable: Boolean = arguments.exists(_.nullable)
|
||||||
|
|
||||||
override def children: Seq[Expression] = arguments ++ functions
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Arguments of the higher ordered function.
|
* Arguments of the higher ordered function.
|
||||||
*/
|
*/
|
||||||
|
@ -182,7 +181,7 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes {
|
||||||
/**
|
/**
|
||||||
* Trait for functions having as input one argument and one function.
|
* Trait for functions having as input one argument and one function.
|
||||||
*/
|
*/
|
||||||
trait SimpleHigherOrderFunction extends HigherOrderFunction {
|
trait SimpleHigherOrderFunction extends HigherOrderFunction with BinaryLike[Expression] {
|
||||||
|
|
||||||
def argument: Expression
|
def argument: Expression
|
||||||
|
|
||||||
|
@ -202,6 +201,9 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction {
|
||||||
|
|
||||||
def functionForEval: Expression = functionsForEval.head
|
def functionForEval: Expression = functionsForEval.head
|
||||||
|
|
||||||
|
override def left: Expression = argument
|
||||||
|
override def right: Expression = function
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
|
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
|
||||||
* in order to save null-check code.
|
* in order to save null-check code.
|
||||||
|
@ -694,7 +696,7 @@ case class ArrayAggregate(
|
||||||
zero: Expression,
|
zero: Expression,
|
||||||
merge: Expression,
|
merge: Expression,
|
||||||
finish: Expression)
|
finish: Expression)
|
||||||
extends HigherOrderFunction with CodegenFallback {
|
extends HigherOrderFunction with CodegenFallback with QuaternaryLike[Expression] {
|
||||||
|
|
||||||
def this(argument: Expression, zero: Expression, merge: Expression) = {
|
def this(argument: Expression, zero: Expression, merge: Expression) = {
|
||||||
this(argument, zero, merge, LambdaFunction.identity)
|
this(argument, zero, merge, LambdaFunction.identity)
|
||||||
|
@ -760,6 +762,11 @@ case class ArrayAggregate(
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "aggregate"
|
override def prettyName: String = "aggregate"
|
||||||
|
|
||||||
|
override def first: Expression = argument
|
||||||
|
override def second: Expression = zero
|
||||||
|
override def third: Expression = merge
|
||||||
|
override def fourth: Expression = finish
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -884,7 +891,7 @@ case class TransformValues(
|
||||||
since = "3.0.0",
|
since = "3.0.0",
|
||||||
group = "lambda_funcs")
|
group = "lambda_funcs")
|
||||||
case class MapZipWith(left: Expression, right: Expression, function: Expression)
|
case class MapZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
extends HigherOrderFunction with CodegenFallback {
|
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
|
||||||
|
|
||||||
def functionForEval: Expression = functionsForEval.head
|
def functionForEval: Expression = functionsForEval.head
|
||||||
|
|
||||||
|
@ -1045,6 +1052,10 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "map_zip_with"
|
override def prettyName: String = "map_zip_with"
|
||||||
|
|
||||||
|
override def first: Expression = left
|
||||||
|
override def second: Expression = right
|
||||||
|
override def third: Expression = function
|
||||||
}
|
}
|
||||||
|
|
||||||
// scalastyle:off line.size.limit
|
// scalastyle:off line.size.limit
|
||||||
|
@ -1063,7 +1074,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
group = "lambda_funcs")
|
group = "lambda_funcs")
|
||||||
// scalastyle:on line.size.limit
|
// scalastyle:on line.size.limit
|
||||||
case class ZipWith(left: Expression, right: Expression, function: Expression)
|
case class ZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
extends HigherOrderFunction with CodegenFallback {
|
extends HigherOrderFunction with CodegenFallback with TernaryLike[Expression] {
|
||||||
|
|
||||||
def functionForEval: Expression = functionsForEval.head
|
def functionForEval: Expression = functionsForEval.head
|
||||||
|
|
||||||
|
@ -1071,7 +1082,7 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
|
|
||||||
override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
|
override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil
|
||||||
|
|
||||||
override def functions: Seq[Expression] = List(function)
|
override def functions: Seq[Expression] = function :: Nil
|
||||||
|
|
||||||
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
|
override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil
|
||||||
|
|
||||||
|
@ -1121,4 +1132,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "zip_with"
|
override def prettyName: String = "zip_with"
|
||||||
|
|
||||||
|
override def first: Expression = left
|
||||||
|
override def second: Expression = right
|
||||||
|
override def third: Expression = function
|
||||||
}
|
}
|
||||||
|
|
|
@ -1488,7 +1488,6 @@ case class WidthBucket(
|
||||||
numBucket: Expression)
|
numBucket: Expression)
|
||||||
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
|
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
|
||||||
|
|
||||||
override def children: Seq[Expression] = Seq(value, minValue, maxValue, numBucket)
|
|
||||||
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
|
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
|
||||||
override def dataType: DataType = LongType
|
override def dataType: DataType = LongType
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
|
@ -1507,4 +1506,9 @@ case class WidthBucket(
|
||||||
"org.apache.spark.sql.catalyst.expressions.WidthBucket" +
|
"org.apache.spark.sql.catalyst.expressions.WidthBucket" +
|
||||||
s".computeBucketNumber($input, $min, $max, $numBucket)")
|
s".computeBucketNumber($input, $min, $max, $numBucket)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def first: Expression = value
|
||||||
|
override def second: Expression = minValue
|
||||||
|
override def third: Expression = maxValue
|
||||||
|
override def fourth: Expression = numBucket
|
||||||
}
|
}
|
||||||
|
|
|
@ -562,7 +562,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
|
||||||
override def dataType: DataType = StringType
|
override def dataType: DataType = StringType
|
||||||
override def inputTypes: Seq[AbstractDataType] =
|
override def inputTypes: Seq[AbstractDataType] =
|
||||||
Seq(StringType, StringType, StringType, IntegerType)
|
Seq(StringType, StringType, StringType, IntegerType)
|
||||||
override def children: Seq[Expression] = subject :: regexp :: rep :: pos :: Nil
|
|
||||||
override def prettyName: String = "regexp_replace"
|
override def prettyName: String = "regexp_replace"
|
||||||
|
|
||||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
|
@ -618,6 +617,11 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
|
||||||
"""
|
"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def first: Expression = subject
|
||||||
|
override def second: Expression = regexp
|
||||||
|
override def third: Expression = rep
|
||||||
|
override def fourth: Expression = pos
|
||||||
}
|
}
|
||||||
|
|
||||||
object RegExpReplace {
|
object RegExpReplace {
|
||||||
|
|
|
@ -593,8 +593,6 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
|
||||||
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
|
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType),
|
||||||
TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
|
TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
|
||||||
|
|
||||||
override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
|
|
||||||
|
|
||||||
override def checkInputDataTypes(): TypeCheckResult = {
|
override def checkInputDataTypes(): TypeCheckResult = {
|
||||||
val inputTypeCheck = super.checkInputDataTypes()
|
val inputTypeCheck = super.checkInputDataTypes()
|
||||||
if (inputTypeCheck.isSuccess) {
|
if (inputTypeCheck.isSuccess) {
|
||||||
|
@ -631,6 +629,11 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
|
||||||
"org.apache.spark.sql.catalyst.expressions.Overlay" +
|
"org.apache.spark.sql.catalyst.expressions.Overlay" +
|
||||||
s".calculate($input, $replace, $pos, $len);")
|
s".calculate($input, $replace, $pos, $len);")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override def first: Expression = input
|
||||||
|
override def second: Expression = replace
|
||||||
|
override def third: Expression = pos
|
||||||
|
override def fourth: Expression = len
|
||||||
}
|
}
|
||||||
|
|
||||||
object StringTranslate {
|
object StringTranslate {
|
||||||
|
|
|
@ -742,7 +742,7 @@ case class NthValue(input: Expression, offset: Expression, ignoreNulls: Boolean)
|
||||||
group = "window_funcs")
|
group = "window_funcs")
|
||||||
// scalastyle:on line.size.limit line.contains.tab
|
// scalastyle:on line.size.limit line.contains.tab
|
||||||
case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
|
case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction
|
||||||
with UnaryLike[Expression] {
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
def this() = this(Literal(1))
|
def this() = this(Literal(1))
|
||||||
|
|
||||||
|
|
|
@ -431,7 +431,7 @@ case class InsertAction(
|
||||||
}
|
}
|
||||||
|
|
||||||
case class Assignment(key: Expression, value: Expression) extends Expression
|
case class Assignment(key: Expression, value: Expression) extends Expression
|
||||||
with Unevaluable with BinaryLike[Expression] {
|
with Unevaluable with BinaryLike[Expression] {
|
||||||
override def nullable: Boolean = false
|
override def nullable: Boolean = false
|
||||||
override def dataType: DataType = throw new UnresolvedException("nullable")
|
override def dataType: DataType = throw new UnresolvedException("nullable")
|
||||||
override def left: Expression = key
|
override def left: Expression = key
|
||||||
|
|
|
@ -850,3 +850,11 @@ trait TernaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
|
||||||
def third: T
|
def third: T
|
||||||
@transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
|
@transient override final lazy val children: Seq[T] = first :: second :: third :: Nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait QuaternaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
|
||||||
|
def first: T
|
||||||
|
def second: T
|
||||||
|
def third: T
|
||||||
|
def fourth: T
|
||||||
|
@transient override final lazy val children: Seq[T] = first :: second :: third :: fourth :: Nil
|
||||||
|
}
|
||||||
|
|
|
@ -560,10 +560,9 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case class HugeCodeIntExpression(value: Int) extends Expression {
|
case class HugeCodeIntExpression(value: Int) extends LeafExpression {
|
||||||
override def nullable: Boolean = true
|
override def nullable: Boolean = true
|
||||||
override def dataType: DataType = IntegerType
|
override def dataType: DataType = IntegerType
|
||||||
override def children: Seq[Expression] = Nil
|
|
||||||
override def eval(input: InternalRow): Any = value
|
override def eval(input: InternalRow): Any = value
|
||||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
// Assuming HugeMethodLimit to be 8000
|
// Assuming HugeMethodLimit to be 8000
|
||||||
|
|
|
@ -22,7 +22,7 @@ import org.apache.hadoop.conf.Configuration
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.sql.{Row, SparkSession}
|
import org.apache.spark.sql.{Row, SparkSession}
|
||||||
import org.apache.spark.sql.catalyst.expressions.Attribute
|
import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
|
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
|
||||||
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
|
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
|
||||||
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
|
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
|
||||||
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
|
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
|
||||||
|
@ -31,7 +31,7 @@ import org.apache.spark.util.SerializableConfiguration
|
||||||
/**
|
/**
|
||||||
* A special `Command` which writes data out and updates metrics.
|
* A special `Command` which writes data out and updates metrics.
|
||||||
*/
|
*/
|
||||||
trait DataWritingCommand extends Command {
|
trait DataWritingCommand extends UnaryCommand {
|
||||||
/**
|
/**
|
||||||
* The input query plan that produces the data to be written.
|
* The input query plan that produces the data to be written.
|
||||||
* IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
|
* IMPORTANT: the input query plan MUST be analyzed, so that we can carry its output columns
|
||||||
|
@ -39,7 +39,7 @@ trait DataWritingCommand extends Command {
|
||||||
*/
|
*/
|
||||||
def query: LogicalPlan
|
def query: LogicalPlan
|
||||||
|
|
||||||
override final def children: Seq[LogicalPlan] = query :: Nil
|
override final def child: LogicalPlan = query
|
||||||
|
|
||||||
// Output column names of the analyzed input query plan.
|
// Output column names of the analyzed input query plan.
|
||||||
def outputColumnNames: Seq[String]
|
def outputColumnNames: Seq[String]
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
|
||||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
|
||||||
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
import org.apache.spark.sql.catalyst.plans.QueryPlan
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.{LeafCommand, LogicalPlan}
|
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
|
||||||
import org.apache.spark.sql.connector.ExternalCommandRunner
|
import org.apache.spark.sql.connector.ExternalCommandRunner
|
||||||
import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
|
import org.apache.spark.sql.execution.{ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
|
||||||
import org.apache.spark.sql.execution.metric.SQLMetric
|
import org.apache.spark.sql.execution.metric.SQLMetric
|
||||||
|
@ -37,7 +37,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
|
||||||
* A logical command that is executed for its side-effects. `RunnableCommand`s are
|
* A logical command that is executed for its side-effects. `RunnableCommand`s are
|
||||||
* wrapped in `ExecutedCommand` during execution.
|
* wrapped in `ExecutedCommand` during execution.
|
||||||
*/
|
*/
|
||||||
trait RunnableCommand extends LeafCommand {
|
trait RunnableCommand extends Command {
|
||||||
|
|
||||||
|
override def children: Seq[LogicalPlan] = Nil
|
||||||
|
|
||||||
// The map used to record the metrics of running the command. This will be passed to
|
// The map used to record the metrics of running the command. This will be passed to
|
||||||
// `ExecutedCommand` during query planning.
|
// `ExecutedCommand` during query planning.
|
||||||
|
|
|
@ -31,7 +31,7 @@ case class AddPartitionExec(
|
||||||
table: SupportsPartitionManagement,
|
table: SupportsPartitionManagement,
|
||||||
partSpecs: Seq[ResolvedPartitionSpec],
|
partSpecs: Seq[ResolvedPartitionSpec],
|
||||||
ignoreIfExists: Boolean,
|
ignoreIfExists: Boolean,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
import DataSourceV2Implicits._
|
import DataSourceV2Implicits._
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespac
|
||||||
case class AlterNamespaceSetPropertiesExec(
|
case class AlterNamespaceSetPropertiesExec(
|
||||||
catalog: SupportsNamespaces,
|
catalog: SupportsNamespaces,
|
||||||
namespace: Seq[String],
|
namespace: Seq[String],
|
||||||
props: Map[String, String]) extends V2CommandExec {
|
props: Map[String, String]) extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
val changes = props.map{ case (k, v) =>
|
val changes = props.map{ case (k, v) =>
|
||||||
NamespaceChange.setProperty(k, v)
|
NamespaceChange.setProperty(k, v)
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
|
||||||
case class AlterTableExec(
|
case class AlterTableExec(
|
||||||
catalog: TableCatalog,
|
catalog: TableCatalog,
|
||||||
ident: Identifier,
|
ident: Identifier,
|
||||||
changes: Seq[TableChange]) extends V2CommandExec {
|
changes: Seq[TableChange]) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdenti
|
||||||
import org.apache.spark.sql.execution.command.CreateViewCommand
|
import org.apache.spark.sql.execution.command.CreateViewCommand
|
||||||
import org.apache.spark.storage.StorageLevel
|
import org.apache.spark.storage.StorageLevel
|
||||||
|
|
||||||
trait BaseCacheTableExec extends V2CommandExec {
|
trait BaseCacheTableExec extends LeafV2CommandExec {
|
||||||
def relationName: String
|
def relationName: String
|
||||||
def planToCache: LogicalPlan
|
def planToCache: LogicalPlan
|
||||||
def dataFrameForCachedPlan: DataFrame
|
def dataFrameForCachedPlan: DataFrame
|
||||||
|
@ -117,7 +117,7 @@ case class CacheTableAsSelectExec(
|
||||||
|
|
||||||
case class UncacheTableExec(
|
case class UncacheTableExec(
|
||||||
relation: LogicalPlan,
|
relation: LogicalPlan,
|
||||||
cascade: Boolean) extends V2CommandExec {
|
cascade: Boolean) extends LeafV2CommandExec {
|
||||||
override def run(): Seq[InternalRow] = {
|
override def run(): Seq[InternalRow] = {
|
||||||
val sparkSession = sqlContext.sparkSession
|
val sparkSession = sqlContext.sparkSession
|
||||||
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)
|
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, relation, cascade)
|
||||||
|
|
|
@ -34,7 +34,7 @@ case class CreateNamespaceExec(
|
||||||
namespace: Seq[String],
|
namespace: Seq[String],
|
||||||
ifNotExists: Boolean,
|
ifNotExists: Boolean,
|
||||||
private var properties: Map[String, String])
|
private var properties: Map[String, String])
|
||||||
extends V2CommandExec {
|
extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||||
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
|
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
|
||||||
|
|
|
@ -33,7 +33,7 @@ case class CreateTableExec(
|
||||||
tableSchema: StructType,
|
tableSchema: StructType,
|
||||||
partitioning: Seq[Transform],
|
partitioning: Seq[Transform],
|
||||||
tableProperties: Map[String, String],
|
tableProperties: Map[String, String],
|
||||||
ignoreIfExists: Boolean) extends V2CommandExec {
|
ignoreIfExists: Boolean) extends LeafV2CommandExec {
|
||||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.sources.Filter
|
||||||
case class DeleteFromTableExec(
|
case class DeleteFromTableExec(
|
||||||
table: SupportsDelete,
|
table: SupportsDelete,
|
||||||
condition: Array[Filter],
|
condition: Array[Filter],
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
table.deleteWhere(condition)
|
table.deleteWhere(condition)
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
|
||||||
case class DescribeColumnExec(
|
case class DescribeColumnExec(
|
||||||
override val output: Seq[Attribute],
|
override val output: Seq[Attribute],
|
||||||
column: Attribute,
|
column: Attribute,
|
||||||
isExtended: Boolean) extends V2CommandExec {
|
isExtended: Boolean) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
val rows = new ArrayBuffer[InternalRow]()
|
val rows = new ArrayBuffer[InternalRow]()
|
||||||
|
|
|
@ -31,7 +31,7 @@ case class DescribeNamespaceExec(
|
||||||
output: Seq[Attribute],
|
output: Seq[Attribute],
|
||||||
catalog: SupportsNamespaces,
|
catalog: SupportsNamespaces,
|
||||||
namespace: Seq[String],
|
namespace: Seq[String],
|
||||||
isExtended: Boolean) extends V2CommandExec {
|
isExtended: Boolean) extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
val rows = new ArrayBuffer[InternalRow]()
|
val rows = new ArrayBuffer[InternalRow]()
|
||||||
val ns = namespace.toArray
|
val ns = namespace.toArray
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsMetadataCo
|
||||||
case class DescribeTableExec(
|
case class DescribeTableExec(
|
||||||
output: Seq[Attribute],
|
output: Seq[Attribute],
|
||||||
table: Table,
|
table: Table,
|
||||||
isExtended: Boolean) extends V2CommandExec {
|
isExtended: Boolean) extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
val rows = new ArrayBuffer[InternalRow]()
|
val rows = new ArrayBuffer[InternalRow]()
|
||||||
addSchema(rows)
|
addSchema(rows)
|
||||||
|
|
|
@ -30,7 +30,7 @@ case class DropNamespaceExec(
|
||||||
namespace: Seq[String],
|
namespace: Seq[String],
|
||||||
ifExists: Boolean,
|
ifExists: Boolean,
|
||||||
cascade: Boolean)
|
cascade: Boolean)
|
||||||
extends V2CommandExec {
|
extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ case class DropPartitionExec(
|
||||||
partSpecs: Seq[ResolvedPartitionSpec],
|
partSpecs: Seq[ResolvedPartitionSpec],
|
||||||
ignoreIfNotExists: Boolean,
|
ignoreIfNotExists: Boolean,
|
||||||
purge: Boolean,
|
purge: Boolean,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
import DataSourceV2Implicits._
|
import DataSourceV2Implicits._
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
|
@ -30,7 +30,7 @@ case class DropTableExec(
|
||||||
ident: Identifier,
|
ident: Identifier,
|
||||||
ifExists: Boolean,
|
ifExists: Boolean,
|
||||||
purge: Boolean,
|
purge: Boolean,
|
||||||
invalidateCache: () => Unit) extends V2CommandExec {
|
invalidateCache: () => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def run(): Seq[InternalRow] = {
|
override def run(): Seq[InternalRow] = {
|
||||||
if (catalog.tableExists(ident)) {
|
if (catalog.tableExists(ident)) {
|
||||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
|
||||||
case class RefreshTableExec(
|
case class RefreshTableExec(
|
||||||
catalog: TableCatalog,
|
catalog: TableCatalog,
|
||||||
ident: Identifier,
|
ident: Identifier,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
catalog.invalidateTable(ident)
|
catalog.invalidateTable(ident)
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ case class RenamePartitionExec(
|
||||||
table: SupportsPartitionManagement,
|
table: SupportsPartitionManagement,
|
||||||
from: ResolvedPartitionSpec,
|
from: ResolvedPartitionSpec,
|
||||||
to: ResolvedPartitionSpec,
|
to: ResolvedPartitionSpec,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ case class RenameTableExec(
|
||||||
newIdent: Identifier,
|
newIdent: Identifier,
|
||||||
invalidateCache: () => Option[StorageLevel],
|
invalidateCache: () => Option[StorageLevel],
|
||||||
cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
|
cacheTable: (SparkSession, LogicalPlan, Option[String], StorageLevel) => Unit)
|
||||||
extends V2CommandExec {
|
extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ case class ReplaceTableExec(
|
||||||
partitioning: Seq[Transform],
|
partitioning: Seq[Transform],
|
||||||
tableProperties: Map[String, String],
|
tableProperties: Map[String, String],
|
||||||
orCreate: Boolean,
|
orCreate: Boolean,
|
||||||
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
|
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
if (catalog.tableExists(ident)) {
|
if (catalog.tableExists(ident)) {
|
||||||
|
@ -59,7 +59,7 @@ case class AtomicReplaceTableExec(
|
||||||
partitioning: Seq[Transform],
|
partitioning: Seq[Transform],
|
||||||
tableProperties: Map[String, String],
|
tableProperties: Map[String, String],
|
||||||
orCreate: Boolean,
|
orCreate: Boolean,
|
||||||
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends V2CommandExec {
|
invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
if (catalog.tableExists(identifier)) {
|
if (catalog.tableExists(identifier)) {
|
||||||
|
|
|
@ -28,7 +28,7 @@ case class SetCatalogAndNamespaceExec(
|
||||||
catalogManager: CatalogManager,
|
catalogManager: CatalogManager,
|
||||||
catalogName: Option[String],
|
catalogName: Option[String],
|
||||||
namespace: Option[Seq[String]])
|
namespace: Option[Seq[String]])
|
||||||
extends V2CommandExec {
|
extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
// The catalog is updated first because CatalogManager resets the current namespace
|
// The catalog is updated first because CatalogManager resets the current namespace
|
||||||
// when the current catalog is set.
|
// when the current catalog is set.
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
|
||||||
case class ShowCurrentNamespaceExec(
|
case class ShowCurrentNamespaceExec(
|
||||||
output: Seq[Attribute],
|
output: Seq[Attribute],
|
||||||
catalogManager: CatalogManager)
|
catalogManager: CatalogManager)
|
||||||
extends V2CommandExec {
|
extends LeafV2CommandExec {
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
|
Seq(toCatalystRow(catalogManager.currentCatalog.name, catalogManager.currentNamespace.quoted))
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Table}
|
||||||
case class ShowTablePropertiesExec(
|
case class ShowTablePropertiesExec(
|
||||||
output: Seq[Attribute],
|
output: Seq[Attribute],
|
||||||
catalogTable: Table,
|
catalogTable: Table,
|
||||||
propertyKey: Option[String]) extends V2CommandExec {
|
propertyKey: Option[String]) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override protected def run(): Seq[InternalRow] = {
|
override protected def run(): Seq[InternalRow] = {
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement
|
||||||
case class TruncatePartitionExec(
|
case class TruncatePartitionExec(
|
||||||
table: SupportsPartitionManagement,
|
table: SupportsPartitionManagement,
|
||||||
partSpec: ResolvedPartitionSpec,
|
partSpec: ResolvedPartitionSpec,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.TruncatableTable
|
||||||
*/
|
*/
|
||||||
case class TruncateTableExec(
|
case class TruncateTableExec(
|
||||||
table: TruncatableTable,
|
table: TruncatableTable,
|
||||||
refreshCache: () => Unit) extends V2CommandExec {
|
refreshCache: () => Unit) extends LeafV2CommandExec {
|
||||||
|
|
||||||
override def output: Seq[Attribute] = Seq.empty
|
override def output: Seq[Attribute] = Seq.empty
|
||||||
|
|
||||||
|
|
|
@ -55,9 +55,8 @@ case class OverwriteByExpressionExecV1(
|
||||||
write: V1Write) extends V1FallbackWriters
|
write: V1Write) extends V1FallbackWriters
|
||||||
|
|
||||||
/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
|
/** Some helper interfaces that use V2 write semantics through the V1 writer interface. */
|
||||||
sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write {
|
sealed trait V1FallbackWriters extends LeafV2CommandExec with SupportsV1Write {
|
||||||
override def output: Seq[Attribute] = Nil
|
override def output: Seq[Attribute] = Nil
|
||||||
override final def children: Seq[SparkPlan] = Nil
|
|
||||||
|
|
||||||
def table: SupportsWrite
|
def table: SupportsWrite
|
||||||
def refreshCache: () => Unit
|
def refreshCache: () => Unit
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
import org.apache.spark.sql.catalyst.encoders.RowEncoder
|
||||||
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
|
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
|
||||||
|
import org.apache.spark.sql.catalyst.trees.LeafLike
|
||||||
import org.apache.spark.sql.execution.SparkPlan
|
import org.apache.spark.sql.execution.SparkPlan
|
||||||
import org.apache.spark.sql.types.StructType
|
import org.apache.spark.sql.types.StructType
|
||||||
|
|
||||||
|
@ -57,8 +58,6 @@ abstract class V2CommandExec extends SparkPlan {
|
||||||
sqlContext.sparkContext.parallelize(result, 1)
|
sqlContext.sparkContext.parallelize(result, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def children: Seq[SparkPlan] = Nil
|
|
||||||
|
|
||||||
override def producedAttributes: AttributeSet = outputSet
|
override def producedAttributes: AttributeSet = outputSet
|
||||||
|
|
||||||
protected def toCatalystRow(values: Any*): InternalRow = {
|
protected def toCatalystRow(values: Any*): InternalRow = {
|
||||||
|
@ -69,3 +68,5 @@ abstract class V2CommandExec extends SparkPlan {
|
||||||
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
|
RowEncoder(StructType.fromAttributes(output)).resolveAndBind().createSerializer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait LeafV2CommandExec extends V2CommandExec with LeafLike[SparkPlan]
|
||||||
|
|
|
@ -24,7 +24,7 @@ import scala.util.Random
|
||||||
|
|
||||||
import org.apache.spark.SparkException
|
import org.apache.spark.SparkException
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.Expression
|
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
|
||||||
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
|
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
|
||||||
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
|
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, UTC}
|
||||||
|
@ -3632,8 +3632,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
|
||||||
}
|
}
|
||||||
|
|
||||||
object DataFrameFunctionsSuite {
|
object DataFrameFunctionsSuite {
|
||||||
case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
|
case class CodegenFallbackExpr(child: Expression) extends UnaryExpression with CodegenFallback {
|
||||||
override def children: Seq[Expression] = Seq(child)
|
|
||||||
override def nullable: Boolean = child.nullable
|
override def nullable: Boolean = child.nullable
|
||||||
override def dataType: DataType = child.dataType
|
override def dataType: DataType = child.dataType
|
||||||
override lazy val resolved = true
|
override lazy val resolved = true
|
||||||
|
|
|
@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
|
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||||
|
import org.apache.spark.sql.catalyst.trees.LeafLike
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
import org.apache.spark.sql.test.SharedSparkSession
|
import org.apache.spark.sql.test.SharedSparkSession
|
||||||
import org.apache.spark.sql.types.{IntegerType, StructType}
|
import org.apache.spark.sql.types.{IntegerType, StructType}
|
||||||
|
@ -358,8 +359,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case class EmptyGenerator() extends Generator {
|
case class EmptyGenerator() extends Generator with LeafLike[Expression] {
|
||||||
override def children: Seq[Expression] = Nil
|
|
||||||
override def elementSchema: StructType = new StructType().add("id", IntegerType)
|
override def elementSchema: StructType = new StructType().add("id", IntegerType)
|
||||||
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
|
override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
|
||||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
|
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
|
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
|
||||||
|
import org.apache.spark.sql.catalyst.trees.UnaryLike
|
||||||
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
|
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
|
||||||
import org.apache.spark.sql.expressions.Window
|
import org.apache.spark.sql.expressions.Window
|
||||||
import org.apache.spark.sql.functions._
|
import org.apache.spark.sql.functions._
|
||||||
|
@ -232,8 +233,9 @@ object TypedImperativeAggregateSuite {
|
||||||
nullable: Boolean = false,
|
nullable: Boolean = false,
|
||||||
mutableAggBufferOffset: Int = 0,
|
mutableAggBufferOffset: Int = 0,
|
||||||
inputAggBufferOffset: Int = 0)
|
inputAggBufferOffset: Int = 0)
|
||||||
extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
|
extends TypedImperativeAggregate[MaxValue]
|
||||||
|
with ImplicitCastInputTypes
|
||||||
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
override def createAggregationBuffer(): MaxValue = {
|
override def createAggregationBuffer(): MaxValue = {
|
||||||
// Returns Int.MinValue if all inputs are null
|
// Returns Int.MinValue if all inputs are null
|
||||||
|
@ -270,8 +272,6 @@ object TypedImperativeAggregateSuite {
|
||||||
|
|
||||||
override lazy val deterministic: Boolean = true
|
override lazy val deterministic: Boolean = true
|
||||||
|
|
||||||
override def children: Seq[Expression] = Seq(child)
|
|
||||||
|
|
||||||
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
|
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
|
||||||
|
|
||||||
override def dataType: DataType = IntegerType
|
override def dataType: DataType = IntegerType
|
||||||
|
|
|
@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
|
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
|
||||||
|
import org.apache.spark.sql.catalyst.trees.UnaryLike
|
||||||
import org.apache.spark.sql.hive.execution.TestingTypedCount.State
|
import org.apache.spark.sql.hive.execution.TestingTypedCount.State
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
|
@ -32,12 +33,11 @@ case class TestingTypedCount(
|
||||||
child: Expression,
|
child: Expression,
|
||||||
mutableAggBufferOffset: Int = 0,
|
mutableAggBufferOffset: Int = 0,
|
||||||
inputAggBufferOffset: Int = 0)
|
inputAggBufferOffset: Int = 0)
|
||||||
extends TypedImperativeAggregate[TestingTypedCount.State] {
|
extends TypedImperativeAggregate[TestingTypedCount.State]
|
||||||
|
with UnaryLike[Expression] {
|
||||||
|
|
||||||
def this(child: Expression) = this(child, 0, 0)
|
def this(child: Expression) = this(child, 0, 0)
|
||||||
|
|
||||||
override def children: Seq[Expression] = child :: Nil
|
|
||||||
|
|
||||||
override def dataType: DataType = LongType
|
override def dataType: DataType = LongType
|
||||||
|
|
||||||
override def nullable: Boolean = false
|
override def nullable: Boolean = false
|
||||||
|
|
Loading…
Reference in a new issue