[SPARK-29977][SQL] Remove newMutableProjection/newOrdering/newNaturalAscendingOrdering from SparkPlan

### What changes were proposed in this pull request?

This PR refactors the `SparkPlan` code; it mainly removes `newMutableProjection`/`newOrdering`/`newNaturalAscendingOrdering` from `SparkPlan`.
The other modifications are listed below:
 - Move `BaseOrdering` from `o.a.s.sql.catalyst.expressions.codegen.GenerateOrdering.scala` to `o.a.s.sql.catalyst.expressions.ordering.scala`
 - Make `RowOrdering` extend `CodeGeneratorWithInterpretedFallback` for `BaseOrdering` (see the usage sketch below)
 - Remove the unused variables (`subexpressionEliminationEnabled` and `codeGenFallBack`) from `SparkPlan`
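
For illustration, a minimal sketch of the new call pattern (the attribute name, type, and example rows below are invented for this sketch; they are not code from the PR):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, RowOrdering, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Hypothetical single-column schema used only for this example.
val attr = AttributeReference("a", IntegerType)()

// Replaces SparkPlan.newOrdering(sortOrder, output): binds the sort
// expressions against the input schema and returns a code-generated
// BaseOrdering, falling back to InterpretedOrdering if codegen fails.
val ordering = RowOrdering.create(Seq(SortOrder(attr, Ascending)), Seq(attr))
assert(ordering.compare(InternalRow(1), InternalRow(2)) < 0)

// Replaces SparkPlan.newNaturalAscendingOrdering(dataTypes).
val natural = RowOrdering.createNaturalAscendingOrdering(Seq(IntegerType))
assert(natural.compare(InternalRow(1), InternalRow(2)) < 0)
```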

### Why are the changes needed?

For better code/test coverage.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #26615 from maropu/RefactorOrdering.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>

@@ -29,19 +29,11 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-/**
- * Inherits some default implementation for Java from `Ordering[Row]`
- */
-class BaseOrdering extends Ordering[InternalRow] {
-  def compare(a: InternalRow, b: InternalRow): Int = {
-    throw new UnsupportedOperationException
-  }
-}
-
 /**
  * Generates bytecode for an [[Ordering]] of rows for a given set of expressions.
  */
-object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging {
+object GenerateOrdering extends CodeGenerator[Seq[SortOrder], BaseOrdering] with Logging {
   protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])

@@ -19,18 +19,28 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
 import org.apache.spark.sql.types._
 
+/**
+ * A base class for generated/interpreted row ordering.
+ */
+class BaseOrdering extends Ordering[InternalRow] {
+  def compare(a: InternalRow, b: InternalRow): Int = {
+    throw new UnsupportedOperationException
+  }
+}
+
 /**
  * An interpreted row ordering comparator.
  */
-class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
+class InterpretedOrdering(ordering: Seq[SortOrder]) extends BaseOrdering {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
     this(bindReferences(ordering, inputSchema))
 
-  def compare(a: InternalRow, b: InternalRow): Int = {
+  override def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
     val size = ordering.size
     while (i < size) {
@@ -67,7 +77,7 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
       }
       i += 1
     }
-    return 0
+    0
   }
 }
@@ -83,7 +93,7 @@ object InterpretedOrdering {
   }
 }
 
-object RowOrdering {
+object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], BaseOrdering] {
 
   /**
    * Returns true iff the data type can be ordered (i.e. can be sorted).
@@ -102,4 +112,26 @@ object RowOrdering {
    * Returns true iff outputs from the expressions can be ordered.
    */
  def isOrderable(exprs: Seq[Expression]): Boolean = exprs.forall(e => isOrderable(e.dataType))
+
+  override protected def createCodeGeneratedObject(in: Seq[SortOrder]): BaseOrdering = {
+    GenerateOrdering.generate(in)
+  }
+
+  override protected def createInterpretedObject(in: Seq[SortOrder]): BaseOrdering = {
+    new InterpretedOrdering(in)
+  }
+
+  def create(order: Seq[SortOrder], inputSchema: Seq[Attribute]): BaseOrdering = {
+    createObject(bindReferences(order, inputSchema))
+  }
+
+  /**
+   * Creates a row ordering for the given schema, in natural ascending order.
+   */
+  def createNaturalAscendingOrdering(dataTypes: Seq[DataType]): BaseOrdering = {
+    val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
+      case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+    }
+    create(order, Seq.empty)
+  }
 }

@@ -29,7 +29,7 @@ import org.apache.spark.internal.config.package$;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.storage.BlockManager;

@@ -71,7 +71,7 @@ case class SortExec(
    * should make it public.
    */
   def createSorter(): UnsafeExternalRowSorter = {
-    val ordering = newOrdering(sortOrder, output)
+    val ordering = RowOrdering.create(sortOrder, output)
 
     // The comparator for comparing prefix
     val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)

@@ -22,9 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.mutable.ArrayBuffer
 
-import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.InternalCompilerException
-
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
@@ -32,13 +29,11 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 object SparkPlan {
@@ -72,16 +67,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   val id: Int = SparkPlan.newPlanId()
 
-  // sqlContext will be null when SparkPlan nodes are created without the active sessions.
-  val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
-    sqlContext.conf.subexpressionEliminationEnabled
-  } else {
-    false
-  }
-
-  // whether we should fallback when hitting compilation errors caused by codegen
-  private val codeGenFallBack = (sqlContext == null) || sqlContext.conf.codegenFallback
-
   /**
    * Return true if this stage of the plan supports columnar execution.
    */
@@ -462,29 +447,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
     buf.toArray
   }
 
-  protected def newMutableProjection(
-      expressions: Seq[Expression],
-      inputSchema: Seq[Attribute],
-      useSubexprElimination: Boolean = false): MutableProjection = {
-    log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema")
-    MutableProjection.create(expressions, inputSchema)
-  }
-
-  protected def newOrdering(
-      order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
-    GenerateOrdering.generate(order, inputSchema)
-  }
-
-  /**
-   * Creates a row ordering for the given schema, in natural ascending order.
-   */
-  protected def newNaturalAscendingOrdering(dataTypes: Seq[DataType]): Ordering[InternalRow] = {
-    val order: Seq[SortOrder] = dataTypes.zipWithIndex.map {
-      case (dt, index) => SortOrder(BoundReference(index, dt, nullable = true), Ascending)
-    }
-    newOrdering(order, Seq.empty)
-  }
-
   /**
    * Cleans up the resources used by the physical operator (if any). In general, all the resources
    * should be cleaned up when the task finishes but operators like SortMergeJoinExec and LimitExec

@@ -126,7 +126,7 @@ case class HashAggregateExec(
         initialInputBufferOffset,
         resultExpressions,
         (expressions, inputSchema) =>
-          newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+          MutableProjection.create(expressions, inputSchema),
         child.output,
         iter,
         testFallbackStartsAt,
@@ -486,10 +486,9 @@ case class HashAggregateExec(
 
     // Create a MutableProjection to merge the rows of same key together
     val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
-    val mergeProjection = newMutableProjection(
+    val mergeProjection = MutableProjection.create(
       mergeExpr,
-      aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
-      subexpressionEliminationEnabled)
+      aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes))
     val joinedRow = new JoinedRow()
 
     var currentKey: UnsafeRow = null

@@ -22,7 +22,7 @@ import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf

@@ -122,7 +122,7 @@ case class ObjectHashAggregateExec(
         initialInputBufferOffset,
         resultExpressions,
         (expressions, inputSchema) =>
-          newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+          MutableProjection.create(expressions, inputSchema),
         child.output,
         iter,
         fallbackCountThreshold,

@@ -93,7 +93,7 @@ case class SortAggregateExec(
           initialInputBufferOffset,
           resultExpressions,
           (expressions, inputSchema) =>
-            newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+            MutableProjection.create(expressions, inputSchema),
           numOutputRows)
         if (!hasInput && groupingExpressions.isEmpty) {
           // There is no input and there is no grouping expressions.

@@ -175,7 +175,7 @@ case class SortMergeJoinExec(
       }
 
       // An ordering that can be used to compare keys from both sides.
-      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+      val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
       val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
 
       joinType match {

@@ -113,7 +113,7 @@ abstract class EvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute],
         }
       }.toArray
     }.toArray
-    val projection = newMutableProjection(allInputs, child.output)
+    val projection = MutableProjection.create(allInputs, child.output)
     val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
       StructField(s"_$i", dt)
     })

@@ -73,7 +73,7 @@ abstract class WindowExecBase(
         RowBoundOrdering(offset)
 
       case (RangeFrame, CurrentRow) =>
-        val ordering = newOrdering(orderSpec, child.output)
+        val ordering = RowOrdering.create(orderSpec, child.output)
         RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
 
       case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
@@ -82,7 +82,7 @@ abstract class WindowExecBase(
         val expr = sortExpr.child
 
         // Create the projection which returns the current 'value'.
-        val current = newMutableProjection(expr :: Nil, child.output)
+        val current = MutableProjection.create(expr :: Nil, child.output)
 
         // Flip the sign of the offset when processing the order is descending
         val boundOffset = sortExpr.direction match {
@@ -97,13 +97,13 @@ abstract class WindowExecBase(
             TimeAdd(expr, boundOffset, Some(timeZone))
           case (a, b) if a == b => Add(expr, boundOffset)
         }
-        val bound = newMutableProjection(boundExpr :: Nil, child.output)
+        val bound = MutableProjection.create(boundExpr :: Nil, child.output)
 
         // Construct the ordering. This is used to compare the result of current value projection
         // to the result of bound value projection. This is done manually because we want to use
         // Code Generation (if it is enabled).
         val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
-        val ordering = newOrdering(boundSortExprs, Nil)
+        val ordering = RowOrdering.create(boundSortExprs, Nil)
         RangeBoundOrdering(ordering, current, bound)
 
       case (RangeFrame, _) =>
@@ -167,7 +167,7 @@ abstract class WindowExecBase(
             ordinal,
             child.output,
             (expressions, schema) =>
-              newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
+              MutableProjection.create(expressions, schema))
         }
 
         // Create the factory
@@ -182,7 +182,7 @@ abstract class WindowExecBase(
             functions.map(_.asInstanceOf[OffsetWindowFunction]),
             child.output,
             (expressions, schema) =>
-              newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
+              MutableProjection.create(expressions, schema),
             offset)
 
         // Entire Partition Frame.

@@ -21,7 +21,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering, SortOrder}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
@@ -41,7 +41,7 @@ case class ReferenceSort(
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
     child.execute().mapPartitions( { iterator =>
-      val ordering = newOrdering(sortOrder, child.output)
+      val ordering = RowOrdering.create(sortOrder, child.output)
       val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
         TaskContext.get(), ordering = Some(ordering))
       sorter.insertAll(iterator.map(r => (r.copy(), null)))